Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def sample_stats_to_xarray(self):
data_warmup = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
if name == "tune":
if name in {"tune", "in_warmup"}:
continue
if self.warmup_trace:
data_warmup[name] = np.array(
Expand Down
11 changes: 10 additions & 1 deletion pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,13 @@ def point(self, idx: int) -> dict[str, np.ndarray]:
"""
raise NotImplementedError()

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
def record(
self,
draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
in_warmup: bool,
):
"""Record results of a sampling iteration.

Parameters
Expand All @@ -122,6 +128,9 @@ def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, An
Values mapped to variable names
stats: list of dicts
The diagnostic values for each sampler
in_warmup: bool
Whether this draw belongs to the warmup phase. This is a driver-owned
concept and is intended for storage/backends to persist warmup information.
"""
raise NotImplementedError()

Expand Down
28 changes: 24 additions & 4 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
BlockedStep,
CompoundStep,
StatsBijection,
check_step_emits_tune,
flat_statname,
flatten_steps,
)
Expand Down Expand Up @@ -106,16 +105,26 @@ def __init__(
{sname: stats_dtypes[fname] for fname, sname, is_obj in sstats}
for sstats in stats_bijection._stat_groups
]
if "in_warmup" in stats_dtypes and self.sampler_vars:
# Expose driver-owned warmup marker via the sampler-stats API.
self.sampler_vars[0].setdefault("in_warmup", stats_dtypes["in_warmup"])

self._chain = chain
self._point_fn = point_fn
self._statsbj = stats_bijection
super().__init__()

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
def record(
self,
draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
in_warmup: bool,
):
values = self._point_fn(draw)
value_dict = dict(zip(self.varnames, values))
stats_dict = self._statsbj.map(stats)
stats_dict["in_warmup"] = bool(in_warmup)
# Apply pickling to objects stats
for fname in self._statsbj.object_stats.keys():
val_bytes = pickle.dumps(stats_dict[fname])
Expand Down Expand Up @@ -148,6 +157,9 @@ def get_sampler_stats(
self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
) -> np.ndarray:
slc = slice(burn, None, thin)
if stat_name in {"in_warmup", "tune"}:
# Backwards-friendly alias for users that might try "tune".
return self._get_stats("in_warmup", slc)
# When there's just one sampler, default to remove the sampler dimension
if sampler_idx is None and self._statsbj.n_samplers == 1:
sampler_idx = 0
Expand Down Expand Up @@ -210,8 +222,6 @@ def make_runmeta_and_point_fn(
) -> tuple[mcb.RunMeta, PointFunc]:
variables, point_fn = get_variables_and_point_fn(model, initial_point)

check_step_emits_tune(step)

# In PyMC the sampler stats are grouped by the sampler.
sample_stats = []
steps = flatten_steps(step)
Expand All @@ -235,6 +245,16 @@ def make_runmeta_and_point_fn(
)
sample_stats.append(svar)

# Driver-owned warmup marker, stored once per draw.
sample_stats.append(
mcb.Variable(
name="in_warmup",
dtype=np.dtype(bool).name,
shape=[],
undefined_ndim=False,
)
)

coordinates = [
mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals)))
for dname, cvals in model.coords.items()
Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
new = np.zeros(draws, dtype=dtype)
data[varname] = np.concatenate([old, new])

def record(self, point, sampler_stats=None) -> None:
def record(self, point, sampler_stats=None, *, in_warmup: bool) -> None:
"""Record results of a sampling iteration.

Parameters
Expand Down Expand Up @@ -238,5 +238,5 @@ def point_fun(point):

chain.fn = point_fun
for point in point_list:
chain.record(point)
chain.record(point, in_warmup=False)
return MultiTrace([chain])
8 changes: 7 additions & 1 deletion pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def buffer(self, group, var_name, value):
buffer[var_name].append(value)

def record(
self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]
self,
draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
in_warmup: bool,
) -> bool | None:
"""Record the step method's returned draw and stats.

Expand All @@ -185,6 +189,7 @@ def record(
self.buffer(group="posterior", var_name=var_name, value=var_value)
for var_name, var_value in self.stats_bijection.map(stats).items():
self.buffer(group="sample_stats", var_name=var_name, value=var_value)
self.buffer(group="sample_stats", var_name="in_warmup", value=bool(in_warmup))
self._buffered_draws += 1
if self._buffered_draws == self.draws_until_flush:
self.flush()
Expand Down Expand Up @@ -525,6 +530,7 @@ def init_trace(
stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
[step] if isinstance(step, BlockedStep) else step.methods
)
stats_dtypes_shapes = {"in_warmup": (bool, [])} | stats_dtypes_shapes
self.init_group_with_empty(
group=self.root.create_group(name="sample_stats", overwrite=True),
var_dtype_and_shape=stats_dtypes_shapes,
Expand Down
10 changes: 9 additions & 1 deletion pymc/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(

self._show_progress = show_progress
self.completed_draws = 0
self.tune = tune
self.total_draws = draws + tune
self.desc = "Sampling chain"
self.chains = chains
Expand All @@ -308,6 +309,7 @@ def _initialize_tasks(self):
draws=0,
total=self.total_draws * self.chains - 1,
chain_idx=0,
in_warmup=self.tune > 0,
sampling_speed=0,
speed_unit="draws/s",
failing=False,
Expand All @@ -323,6 +325,7 @@ def _initialize_tasks(self):
draws=0,
total=self.total_draws - 1,
chain_idx=chain_idx,
in_warmup=self.tune > 0,
sampling_speed=0,
speed_unit="draws/s",
failing=False,
Expand Down Expand Up @@ -381,6 +384,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
self.tasks[chain_idx],
completed=draw,
draws=draw,
in_warmup=tuning,
sampling_speed=speed,
speed_unit=unit,
failing=failing,
Expand All @@ -391,13 +395,17 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
self._progress.update(
self.tasks[chain_idx],
draws=draw + 1 if not self.combined_progress else draw,
in_warmup=False,
failing=failing,
**all_step_stats,
refresh=True,
)

def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
columns = [
TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1)),
TextColumn("{task.fields[in_warmup]}", table_column=Column("Warmup", ratio=1)),
]

if self.full_stats:
columns += step_columns
Expand Down
22 changes: 7 additions & 15 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,18 +1043,10 @@ def _sample_return(
else:
traces, length = _choose_chains(traces, 0)
mtrace = MultiTrace(traces)[:length]
# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if "tune" in mtrace.stat_names:
# Get the tune stat directly from chain 0, sampler 0
stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
stat = tuple(stat)
n_tune = stat.count(True)
n_draws = stat.count(False)
else:
# these may be wrong when KeyboardInterrupt happened, but they're better than nothing
n_tune = min(tune, len(mtrace))
n_draws = max(0, len(mtrace) - n_tune)
# Count the number of tune/draw iterations that happened.
# The warmup/draw boundary is owned by the sampling driver.
n_tune = min(tune, len(mtrace))
n_draws = max(0, len(mtrace) - n_tune)

if discard_tuned_samples:
mtrace = mtrace[n_tune:]
Expand Down Expand Up @@ -1221,7 +1213,7 @@ def _sample(
try:
for it, stats in enumerate(sampling_gen):
progress_manager.update(
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it < tune
)

if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
Expand Down Expand Up @@ -1292,7 +1284,7 @@ def _iter_sample(
step.stop_tuning()

point, stats = step.step(point)
trace.record(point, stats)
trace.record(point, stats, in_warmup=i < tune)
log_warning_stats(stats)

if callback is not None:
Expand Down Expand Up @@ -1405,7 +1397,7 @@ def _mp_sample(
strace = traces[draw.chain]
if not zarr_recording:
# Zarr recording happens in each process
strace.record(draw.point, draw.stats)
strace.record(draw.point, draw.stats, in_warmup=draw.tuning)
log_warning_stats(draw.stats)

if callback is not None:
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _start_loop(self):
raise KeyboardInterrupt()
elif msg[0] == "write_next":
if zarr_recording:
self._zarr_chain.record(point, stats)
self._zarr_chain.record(point, stats, in_warmup=tuning)
self._write_point(point)
is_last = draw + 1 == self._draws + self._tune
self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def _iter_population(
# apply the update to the points and record to the traces
for c, strace in enumerate(traces):
points[c], stats = updates[c]
flushed = strace.record(points[c], stats)
flushed = strace.record(points[c], stats, in_warmup=i < tune)
log_warning_stats(stats)
if flushed and isinstance(strace, ZarrChain):
sampling_state = popstep.request_sampling_state(c)
Expand Down
2 changes: 1 addition & 1 deletion pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _posterior_to_trace(self, chain=0) -> NDArray:
var_samples = np.round(var_samples).astype(var.dtype)
value.append(var_samples.reshape(shape))
size += new_size
strace.record(point=dict(zip(varnames, value)))
strace.record(point=dict(zip(varnames, value)), in_warmup=False)
return strace


Expand Down
14 changes: 4 additions & 10 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def infer_warn_stats_info(
sds[sname] = (dtype, None)
elif sds:
stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()})

# Even when a step method does not emit any stats, downstream components still assume one stats "slot" per step method. Represent that with a single empty dict.
if not stats_dtypes:
stats_dtypes.append({})
return stats_dtypes, sds


Expand Down Expand Up @@ -351,16 +355,6 @@ def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
return steps


def check_step_emits_tune(step: CompoundStep | BlockedStep):
if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
elif isinstance(step, CompoundStep):
for sstep in step.methods:
if "tune" not in sstep.stats_dtypes_shapes:
raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.")
return


class StatsBijection:
"""Map between a `list` of stats to `dict` of stats."""

Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
self.iter_count += 1

stats: dict[str, Any] = {
"tune": self.tune,
"diverging": diverging,
"divergences": self.divergences,
"perf_counter_diff": perf_end - perf_start,
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class HamiltonianMC(BaseHMC):
stats_dtypes_shapes = {
"step_size": (np.float64, []),
"n_steps": (np.int64, []),
"tune": (bool, []),
"step_size_bar": (np.float64, []),
"accept": (np.float64, []),
"diverging": (bool, []),
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ class NUTS(BaseHMC):
stats_dtypes_shapes = {
"depth": (np.int64, []),
"step_size": (np.float64, []),
"tune": (bool, []),
"mean_tree_accept": (np.float64, []),
"step_size_bar": (np.float64, []),
"tree_size": (np.float64, []),
Expand Down
Loading
Loading