Conversation

@eclipse1605

Description

I tried to make "tuning vs. draws" a driver-owned concept again. Right now, parts of sampling/postprocessing infer the warmup length from a per-step "tune" sampler stat, which can get out of sync with what the driver actually requested (e.g. a step method that returns "tune": False everywhere makes PyMC think n_tune == 0, so the warmup draws aren't discarded and the logs look wrong).
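
For illustration, a minimal hypothetical sketch of the mismatch (the names below are made up, not the actual PyMC internals): counting per-step "tune" stats collapses to zero when the step method never reports tuning, while the driver-owned count stays correct.

# Hypothetical sketch, not actual PyMC code.
# Suppose the driver ran 1000 warmup draws followed by 500 posterior draws,
# but the step method reported "tune": False for every draw.
stats_per_draw = [{"tune": False} for _ in range(1500)]

n_tune_inferred = sum(s["tune"] for s in stats_per_draw)  # 0, so no warmup draws would be discarded
n_tune_driver = 1000  # what the driver actually requested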

Related Issues

Fixes: #7997
Context: #7776 (progressbar/stat refactor that exposed the mismatch)
Related discussion/attempts: #7730, #7721, #7724, #8014

@ricardoV94
Member

@OriolAbril / @aloctavodia does any part of ArviZ require the step samplers to emit a tune flag? Is it enough that we have a warmup / posterior distinction, each with its own number of draws?

return steps


def check_step_emits_tune(step: CompoundStep | BlockedStep):
Member

If this ends up working without tune, remove this function as well

@ricardoV94
Member

ricardoV94 commented Dec 21, 2025

Taking a step back, would it make sense to have a tune=None mode where the sampler(s) decide how much tuning they need? In that case it would make sense for the individual steps to report back whether they're tuning or not.

Even if that's the case, I think it still makes sense to remove this currently useless stat and reintroduce it in a separate PR (provided nobody finds a reason why it is actually useful/needed).

CC @aloctavodia, @lucianopaz, @aseyboldt

@codecov

codecov bot commented Dec 21, 2025

Codecov Report

❌ Patch coverage is 89.65517% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.48%. Comparing base (cadb97a) to head (46a9b3c).
⚠️ Report is 21 commits behind head on main.

Files with missing lines Patch % Lines
pymc/backends/ndarray.py 50.00% 1 Missing ⚠️
pymc/sampling/parallel.py 0.00% 1 Missing ⚠️
pymc/smc/kernels.py 0.00% 1 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #8015      +/-   ##
==========================================
+ Coverage   90.22%   91.48%   +1.26%     
==========================================
  Files         116      116              
  Lines       18972    19059      +87     
==========================================
+ Hits        17117    17437     +320     
+ Misses       1855     1622     -233     
Files with missing lines Coverage Δ
pymc/backends/arviz.py 96.04% <100.00%> (ø)
pymc/backends/base.py 88.26% <100.00%> (-0.44%) ⬇️
pymc/backends/mcbackend.py 99.28% <100.00%> (+0.02%) ⬆️
pymc/backends/zarr.py 93.85% <100.00%> (+0.04%) ⬆️
pymc/progress_bar.py 93.42% <100.00%> (+0.04%) ⬆️
pymc/sampling/mcmc.py 91.55% <100.00%> (+6.11%) ⬆️
pymc/sampling/population.py 65.47% <100.00%> (-5.36%) ⬇️
pymc/step_methods/compound.py 98.68% <100.00%> (+0.81%) ⬆️
pymc/step_methods/hmc/base_hmc.py 92.25% <ø> (ø)
pymc/step_methods/hmc/hmc.py 94.59% <ø> (ø)
... and 7 more

... and 21 files with indirect coverage changes


@michaelosthege
Member

Taking a step back, would it make sense for a tune=None mode where the sampler(s) decide how much tune they need? In that case it would make sense for the individual steps to report back whether they're tuning or not.

Automatically stopping the warmup early would be nice. I think we should agree on cleanly separated definitions of warmup, burn-in and tuning. Samplers not needing to tune parameters doesn't mean that there's no need for a warmup phase of burn-in iterations (however one might call it).

Our current implementation is bad because it doesn't separate the concepts.

@aloctavodia
Member

ArviZ does not require or use a "tune" stat anywhere.

@eclipse1605
Author

#7997 (comment)

Member

@michaelosthege michaelosthege left a comment

I like where this is going!

Using slightly different naming, I think we can simplify a bit more.

draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
tune: bool | None = None,
Member

I would suggest calling it in_warmup to align with the commonly understood PyMC/ArviZ terminology of a warmup vs. posterior phase.

Additionally, I think we should not set a default value.

The sampling iterators in sampling/mcmc.py (what you call the "driver") should always pass the new parameter, which breaks any custom trace backend that does not expect it.
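
A rough sketch of what that could look like (illustrative only; the class name and docstring are placeholders, not the final PyMC code):

from collections.abc import Mapping, Sequence
from typing import Any

import numpy as np


class CustomTrace:  # placeholder name for any trace backend
    def record(
        self,
        draw: Mapping[str, np.ndarray],
        stats: Sequence[Mapping[str, Any]],
        *,
        in_warmup: bool,  # keyword-only, no default: the driver always passes it
    ) -> None:
        """Store one draw, its sampler stats, and the driver-owned warmup flag."""
        raise NotImplementedError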

values = self._point_fn(draw)
value_dict = dict(zip(self.varnames, values))
stats_dict = self._statsbj.map(stats)
stats_dict["tune"] = bool(tune)
Member

If we call it in_warmup, I can change this line in McBackend to look for it.

This has the advantage of not colliding with a sampler-emitted "tune" stat, which thereby becomes an optional, sampler-owned sampler stat.
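
A tiny self-contained sketch of that distinction (illustrative values, not McBackend code):

in_warmup = True  # driver-owned flag passed to record()
step_stats = {"accept": 0.9, "tune": True}  # optional, sampler-owned stats

stats_dict = dict(step_stats)
stats_dict["in_warmup"] = bool(in_warmup)  # stored under its own key, so it never collides with "tune"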

Author

Makes sense, I'll make the change across drivers/backends/progress bar/tests.

@eclipse1605
Author

@michaelosthege does this make sense?


point, stats = step.step(point)
- trace.record(point, stats, tune=i < tune)
+ _record_with_in_warmup(trace, point, stats, in_warmup=i < tune)
Member

Inspecting the signature thousands of times creates a lot of unnecessary overhead.

I would prefer the breaking change of requiring custom trace backends to adapt immediately. (If there are any custom trace backends in existence at all.)

Author

True, but for what it's worth, the signature inspection isn't happening per draw; it's cached per trace-backend class. Still, I agree the hot path shouldn't be doing any of this, so I can drop the compatibility shim and treat this as a breaking change. As you said, custom trace backends will need to update their record(..., *, in_warmup: bool) signature. Does that sound OK?
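
For reference, a shim along these lines is what is under discussion here (a hypothetical sketch, not the actual helper): the signature check runs once per backend class and the flag is only forwarded to backends that accept it.

import inspect
from functools import lru_cache


@lru_cache(maxsize=None)
def _accepts_in_warmup(trace_cls) -> bool:
    # Inspect the backend's record() signature once per class, not per draw.
    return "in_warmup" in inspect.signature(trace_cls.record).parameters


def _record_with_in_warmup(trace, point, stats, *, in_warmup):
    # Forward the driver-owned flag only to backends whose record() accepts it.
    if _accepts_in_warmup(type(trace)):
        trace.record(point, stats, in_warmup=in_warmup)
    else:
        trace.record(point, stats)

Dropping the shim means every trace backend's record() must accept in_warmup directly and the hot path becomes a plain call.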

Member

Yes, let's do it. Less code to maintain.

@eclipse1605
Author

@michaelosthege check this out

Member

@michaelosthege michaelosthege left a comment

I'm not familiar with how the progress bar gets updated. Possibly my two comments on that matter are invalid, but please check them.

I'll also trigger the CI tests

return steps


def check_step_emits_tune(step: CompoundStep | BlockedStep):
Member

the whole function can be removed

Comment on lines 330 to 351
def _progressbar_config(n_chains=1):
columns = [
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
TextColumn(
"{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1)
),
]

stats = {
"tune": [True] * n_chains,
"scaling": [0] * n_chains,
"accept_rate": [0.0] * n_chains,
}

return columns, stats

@staticmethod
def _make_progressbar_update_functions():
def update_stats(step_stats):
return {
"accept_rate" if key == "accept" else key: step_stats[key]
- for key in ("tune", "accept", "scaling")
+ for key in ("accept", "scaling")
}
Member

This part can now use the in_warmup stat?

Comment on lines +207 to +214
stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}

return columns, stats

@staticmethod
def _make_progressbar_update_functions():
def update_stats(step_stats):
- return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}
+ return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}}
Member

same here?

Comment on lines 296 to 302
tune = mtrace._straces[0].get_sampler_stats("tune")
assert isinstance(tune, np.ndarray)
# warmup is tracked by the sampling driver
if discard_warmup:
assert tune.shape == (7, 3)
assert len(mtrace) == 7
else:
assert tune.shape == (12, 3)
pass
Member

can this test remain as before, but using the in_warmup stat instead?

@eclipse1605
Author

I'm not familiar with how the progress bar gets updated. Possibly my two comments on that matter are invalid, but please check them.

Hey, sorry for the delay, but I think they're valid, because warmup bookkeeping is now explicitly driver-owned.

@eclipse1605
Author

@michaelosthege I've made the tests consistent with the changes; if you run the CI tests again they should mostly pass now.

Member

@michaelosthege michaelosthege left a comment

Looks good to me!

Thanks @eclipse1605 for your endurance with this!

@eclipse1605
Author

Looks good to me!

Thanks @eclipse1605 for your endurance with this!

Thanks a ton for the reviews and guidance @michaelosthege and @ricardoV94, really appreciate the patience since I'm still getting my bearings here :)


# Doesn't actually tune, but it's required to emit a sampler stat
# that indicates whether a draw was done in a tuning phase.
self.tune = True
Member

remove self.tune

@staticmethod
def _progressbar_config(n_chains=1):
columns = [
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
Member

@ricardoV94 ricardoV94 Jan 7, 2026

We still want to show this information in the task bar; where is this defined now?

Author

@eclipse1605 eclipse1605 Jan 8, 2026

We still want to show this information in the task bar; where is this defined now?

Warmup now comes straight from ProgressBarManager: each task is initialised and updated with the driver's in_warmup flag, so the column still renders. You can see the diff of progress_bar.py.
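
A minimal sketch of such a column, reusing the rich primitives from the snippets above (the field name and wiring are illustrative assumptions, not the actual progress_bar.py code):

from rich.progress import TextColumn
from rich.table import Column

# Hypothetical: the "Tuning" column reads a task field that the ProgressBarManager
# keeps up to date from the driver's in_warmup flag, not from a per-step "tune" stat.
warmup_column = TextColumn("{task.fields[in_warmup]}", table_column=Column("Tuning", ratio=1))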

test_dict = {
    "posterior": ["u1", "n1"],
-   "sample_stats": ["~tune", "accept"],
+   "sample_stats": ["~in_warmup", "accept"],
Member

@ricardoV94 ricardoV94 Jan 7, 2026

I'm not sure about changing the output variable name, this seems like a breaking change for users?

The specific line I pointed to may not be relevant. The general question is whether we changed anything in MultiTrace/InferenceData output with this PR other than the tune flag not existing per step.

Author

It now writes the warmup flag once as in_warmup, but for users nothing new shows up. When we persist sampler stats (e.g. in mcbackend) we store that boolean and keep trace.get_sampler_stats("tune") working by aliasing it to the new field. The default NDArray backend still omits both names, just like before, and to_inference_data continues to drop whichever warmup marker exists, so the resulting InferenceData matches main; the test only switches the "absent" check to the new internal name. No other MultiTrace/InferenceData variables changed.
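
A rough sketch of the aliasing idea (a hypothetical wrapper for illustration only, not the actual backend code):

import numpy as np


class AliasingStats:  # hypothetical container for persisted sampler stats
    def __init__(self, sampler_stats: dict[str, np.ndarray]):
        self._stats = sampler_stats

    def get_sampler_stats(self, name: str) -> np.ndarray:
        # Old callers asking for "tune" transparently get the driver-owned flag.
        if name == "tune" and "tune" not in self._stats:
            name = "in_warmup"
        return self._stats[name]

For example, AliasingStats({"in_warmup": np.array([True, False])}).get_sampler_stats("tune") returns the in_warmup array.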

@ricardoV94 ricardoV94 requested a review from lucianopaz January 7, 2026 22:01
@eclipse1605 eclipse1605 requested a review from ricardoV94 January 8, 2026 19:10
Member

@ricardoV94 ricardoV94 left a comment

This looks sleek; I just want to do a manual integration test locally before merging.

@eclipse1605
Author

This looks sleek; I just want to do a manual integration test locally before merging.

sounds good!

@eclipse1605
Author

Hey @ricardoV94, I tried to understand the failed test but didn't really get very far with it. Is it failing because JAX spits out NaNs when the Dirichlet concentration is very skewed, so the multinomial never sees a clean probability vector?

@ricardoV94
Member

That one fails now and then, don't worry about it

@eclipse1605
Author

That one fails now and then, don't worry about it

ok sure :)


Development

Successfully merging this pull request may close these issues.

BUG: CategoricalGibbsMetropolis doesn't respect the tune parameter
