-
Notifications
You must be signed in to change notification settings - Fork 2.2k
attempt to fix warmup bookkeeping: dropped the tune stat #8015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@OriolAbril / @aloctavodia does any part of Arviz require the step samples to have a tune flag? Is it enough that we have warmup / posterior distinction, each with their number of draws? |
pymc/step_methods/compound.py
Outdated
| return steps | ||
|
|
||
|
|
||
| def check_step_emits_tune(step: CompoundStep | BlockedStep): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this ends up working without tune, remove this function as well
|
Taking a step back, would it make sense for a Even if that's the case, I think it still makes sense to remove this currently useless stat and reintroduce in a separate PR (provided nobody finds a reason why it is actually useful/needed). |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #8015 +/- ##
==========================================
+ Coverage 90.22% 91.48% +1.26%
==========================================
Files 116 116
Lines 18972 19059 +87
==========================================
+ Hits 17117 17437 +320
+ Misses 1855 1622 -233
🚀 New features to boost your workflow:
|
Automatically stopping the warmup early would be nice. I think we should agree on cleanly separated definitions of warmup, burn-in and tuning. Samplers not needing to tune parameters doesn't mean that there's no need for a warmup phase of burn-in iterations (however one might call it). Our current implementation is bad because it doesn't separate the concepts. |
|
ArviZ does not require or use a "tune" stats anywhere. |
michaelosthege
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like where this is going!
Using a slightly different naming I think we can simplify a bit more.
pymc/backends/base.py
Outdated
| draw: Mapping[str, np.ndarray], | ||
| stats: Sequence[Mapping[str, Any]], | ||
| *, | ||
| tune: bool | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest to call it in_warmup to align with the commonly understood PyMC/ArviZ terminology of warmup vs. posterior phase.
Additionally I think we should not set a default value.
The sampling iterators in sampling/mcmc.py (what you call "driver") should always pass the new parameter and that breaks any custom trace that does not expect it.
pymc/backends/mcbackend.py
Outdated
| values = self._point_fn(draw) | ||
| value_dict = dict(zip(self.varnames, values)) | ||
| stats_dict = self._statsbj.map(stats) | ||
| stats_dict["tune"] = bool(tune) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we call it in_warmup I can change this line in McBackend to look for it.
This has the advantage of not colliding with sampler-emitted "tune" stats which thereby become an optional, sampler-owned sampler-stat.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense, ill make the change across drivers/backends/progress bar/tests.
|
@michaelosthege does this make sense? |
pymc/sampling/mcmc.py
Outdated
|
|
||
| point, stats = step.step(point) | ||
| trace.record(point, stats, tune=i < tune) | ||
| _record_with_in_warmup(trace, point, stats, in_warmup=i < tune) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inspecting the signature thousands of times creates a lot of unnecessary overhead.
I would prefer the breaking change of requiring custom trace backends to adapt immediately. (If there are any custom trace backends in existence at all.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
true, but for what its worth the signature inspection isn’t happening per draw; it’s cached per trace-backend class. still, i agree the hot path shouldn’t be doing any of this, so i can drop the compatibility shim and treat this as a breaking change. as you said, custom trace backends will need to update their record(..., *, in_warmup: bool) signature. does that sound ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes let's do it. Less code to maintain.
|
@michaelosthege check this out |
michaelosthege
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar with how the progress bar gets updated. Possibly my two comments on that matter are invalid, but please check them.
I'll also trigger the CI tests
pymc/step_methods/compound.py
Outdated
| return steps | ||
|
|
||
|
|
||
| def check_step_emits_tune(step: CompoundStep | BlockedStep): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the whole function can be removed
| def _progressbar_config(n_chains=1): | ||
| columns = [ | ||
| TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), | ||
| TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), | ||
| TextColumn( | ||
| "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1) | ||
| ), | ||
| ] | ||
|
|
||
| stats = { | ||
| "tune": [True] * n_chains, | ||
| "scaling": [0] * n_chains, | ||
| "accept_rate": [0.0] * n_chains, | ||
| } | ||
|
|
||
| return columns, stats | ||
|
|
||
| @staticmethod | ||
| def _make_progressbar_update_functions(): | ||
| def update_stats(step_stats): | ||
| return { | ||
| "accept_rate" if key == "accept" else key: step_stats[key] | ||
| for key in ("tune", "accept", "scaling") | ||
| for key in ("accept", "scaling") | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part can now use the in_warmup stat?
| stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} | ||
|
|
||
| return columns, stats | ||
|
|
||
| @staticmethod | ||
| def _make_progressbar_update_functions(): | ||
| def update_stats(step_stats): | ||
| return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} | ||
| return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here?
| tune = mtrace._straces[0].get_sampler_stats("tune") | ||
| assert isinstance(tune, np.ndarray) | ||
| # warmup is tracked by the sampling driver | ||
| if discard_warmup: | ||
| assert tune.shape == (7, 3) | ||
| assert len(mtrace) == 7 | ||
| else: | ||
| assert tune.shape == (12, 3) | ||
| pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this test remain as before, but using the in_warmup stat instead?
hey, sorry for the delay but i think they're valid because warmup bookkeeping is now explicitly driver owned |
|
@michaelosthege ive made the tests consistent with the changes, running the ci tests again will mostly pass now |
michaelosthege
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
Thanks @eclipse1605 for your endurance with this!
thanks a ton for the reviews and guidance @michaelosthege and @ricardoV94, really appreciate the patience since im still getting my bearings here :) |
pymc/step_methods/metropolis.py
Outdated
|
|
||
| # Doesn't actually tune, but it's required to emit a sampler stat | ||
| # that indicates whether a draw was done in a tuning phase. | ||
| self.tune = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove self.tune
| @staticmethod | ||
| def _progressbar_config(n_chains=1): | ||
| columns = [ | ||
| TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still want to show this information in the task bar, where is this defined now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still want to show this information in the task bar, where is this defined now?
warmup now comes straight from ProgressBarManager, each task is initialised and updated with the driver’s in_warmup flag, so the column still renders. you can see the diff of progress_bar.py
| test_dict = { | ||
| "posterior": ["u1", "n1"], | ||
| "sample_stats": ["~tune", "accept"], | ||
| "sample_stats": ["~in_warmup", "accept"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about changing the output variable name, this seems like a breaking change for users?
The specific line I pointed to may not be relevant. The general question is whether we changed anything in MultiTrace/InferenceData output with this PR other than the tune flag not existing per step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it now writes the warmup flag once as in_warmup, but for users, nothing new shows up. when we persist sampler stats (e.g. in mcbackend) we store that boolean and keep trace.get_sampler_stats("tune") working by aliasing to the new field. the default NDArray backend still omits both names, just like before. and to_inference_data continues to drop whichever warmup marker exists, so the resulting InferenceData matches main; the test only switches the "absent" check to the new internal name. no other MultiTrace/InferenceData variables changed.
ricardoV94
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks sleek, I just want to do a manual integration test locally before merging
sounds good! |
|
hey @ricardoV94 i tried to understand the failed test but didn't really get very far with it. is it failing because jax spits out NaNs when the dirichlet concentration is super skewed, so the multinomial never sees a clean prob vector? |
|
That one fails now and then, don't worry about it |
ok sure :) |
Description
I tried to make “tuning vs draws” a driver owned concept again. Right now, parts of sampling/postprocessing infer warmup length from a per-step
"tune"sampler stat, which can get out of sync (e.g. a step method returning"tune": Falseeverywhere makes PyMC thinkn_tune == 0, so warmup isn’t discarded and the logs look wrong).Related Issues
Fixes: #7997
Context: #7776 (progressbar/stat refactor that exposed the mismatch)
Related discussion/attempts: #7730, #7721, #7724, #8014