Skip to content

Commit 46c2c4a

Browse files
committed
Fix progressbar with nested compound step samplers
1 parent 9f3a119 commit 46c2c4a

File tree

6 files changed

+82
-59
lines changed

6 files changed

+82
-59
lines changed

pymc/step_methods/compound.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1):
189189
return columns, stats
190190

191191
@staticmethod
192-
def _make_update_stats_function():
193-
def update_stats(stats, step_stats, chain_idx):
194-
return stats
192+
def _make_update_stats_functions():
193+
def update_stats(step_stats):
194+
return step_stats
195195

196-
return update_stats
196+
return (update_stats,)
197197

198198
# Hack for creating the class correctly when unpickling.
199199
def __getnewargs_ex__(self):
@@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1):
332332

333333
return columns, stats
334334

335-
def _make_update_stats_function(self):
336-
update_fns = [method._make_update_stats_function() for method in self.methods]
337-
338-
def update_stats(stats, step_stats, chain_idx):
339-
for step_stat, update_fn in zip(step_stats, update_fns):
340-
stats = update_fn(stats, step_stat, chain_idx)
341-
342-
return stats
343-
344-
return update_stats
335+
def _make_update_stats_functions(self):
336+
update_functions = []
337+
for method in self.methods:
338+
update_functions.extend(method._make_update_stats_functions())
339+
return update_functions
345340

346341

347342
def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:

pymc/step_methods/hmc/nuts.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,11 @@ def _progressbar_config(n_chains=1):
248248
return columns, stats
249249

250250
@staticmethod
251-
def _make_update_stats_function():
252-
def update_stats(stats, step_stats, chain_idx):
253-
if isinstance(step_stats, list):
254-
step_stats = step_stats[0]
251+
def _make_update_stats_functions():
252+
def update_stats(stats):
253+
return {key: stats[key] for key in ("diverging", "step_size", "tree_size")}
255254

256-
if not step_stats["tune"]:
257-
stats["divergences"][chain_idx] += step_stats["diverging"]
258-
259-
stats["step_size"][chain_idx] = step_stats["step_size"]
260-
stats["tree_size"][chain_idx] = step_stats["tree_size"]
261-
return stats
262-
263-
return update_stats
255+
return (update_stats,)
264256

265257

266258
# A proposal for the next position

pymc/step_methods/metropolis.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1):
346346
return columns, stats
347347

348348
@staticmethod
349-
def _make_update_stats_function():
350-
def update_stats(stats, step_stats, chain_idx):
351-
if isinstance(step_stats, list):
352-
step_stats = step_stats[0]
353-
354-
stats["tune"][chain_idx] = step_stats["tune"]
355-
stats["accept_rate"][chain_idx] = step_stats["accept"]
356-
stats["scaling"][chain_idx] = step_stats["scaling"]
357-
358-
return stats
359-
360-
return update_stats
349+
def _make_update_stats_functions():
350+
def update_stats(step_stats):
351+
return {
352+
"accept_rate" if key == "accept" else key: step_stats[key]
353+
for key in ("tune", "accept", "scaling")
354+
}
355+
356+
return (update_stats,)
361357

362358

363359
def tune(scale, acc_rate):

pymc/step_methods/slicer.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1):
212212
return columns, stats
213213

214214
@staticmethod
215-
def _make_update_stats_function():
216-
def update_stats(stats, step_stats, chain_idx):
217-
if isinstance(step_stats, list):
218-
step_stats = step_stats[0]
215+
def _make_update_stats_functions():
216+
def update_stats(step_stats):
217+
return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}
219218

220-
stats["tune"][chain_idx] = step_stats["tune"]
221-
stats["nstep_out"][chain_idx] = step_stats["nstep_out"]
222-
stats["nstep_in"][chain_idx] = step_stats["nstep_in"]
223-
224-
return stats
225-
226-
return update_stats
219+
return (update_stats,)

pymc/util.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -763,9 +763,8 @@ def __init__(
763763
progressbar=progressbar,
764764
progressbar_theme=progressbar_theme,
765765
)
766-
767766
self.progress_stats = progress_stats
768-
self.update_stats = step_method._make_update_stats_function()
767+
self.update_stats_functions = step_method._make_update_stats_functions()
769768

770769
self._show_progress = show_progress
771770
self.divergences = 0
@@ -829,27 +828,46 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
829828
if not tuning and stats and stats[0].get("diverging"):
830829
self.divergences += 1
831830

832-
self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
833-
more_updates = (
834-
{stat: value[chain_idx] for stat, value in self.progress_stats.items()}
835-
if self.full_stats
836-
else {}
837-
)
831+
if self.full_stats:
832+
# TODO: Index by chain already?
833+
chain_progress_stats = [
834+
update_states_fn(step_stats)
835+
for update_states_fn, step_stats in zip(
836+
self.update_stats_functions, stats, strict=True
837+
)
838+
]
839+
all_step_stats = {}
840+
for step_stats in chain_progress_stats:
841+
for key, val in step_stats.items():
842+
if key in all_step_stats:
843+
# TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
844+
continue
845+
else:
846+
all_step_stats[key] = val
847+
848+
else:
849+
all_step_stats = {}
850+
851+
# more_updates = (
852+
# {stat: value[chain_idx] for stat, value in progress_stats.items()}
853+
# if self.full_stats
854+
# else {}
855+
# )
838856

839857
self._progress.update(
840858
self.tasks[chain_idx],
841859
completed=draw,
842860
draws=draw,
843861
sampling_speed=speed,
844862
speed_unit=unit,
845-
**more_updates,
863+
**all_step_stats,
846864
)
847865

848866
if is_last:
849867
self._progress.update(
850868
self.tasks[chain_idx],
851869
draws=draw + 1 if not self.combined_progress else draw,
852-
**more_updates,
870+
**all_step_stats,
853871
refresh=True,
854872
)
855873

tests/test_util.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,32 @@ def test_get_value_vars_from_user_vars():
250250
get_value_vars_from_user_vars([x2], model1)
251251
with pytest.raises(ValueError, match=rf"{prefix} \['det2'\]"):
252252
get_value_vars_from_user_vars([det2], model2)
253+
254+
255+
def test_progressbar_nested_compound():
256+
# Regression test for https://github.com/pymc-devs/pymc/issues/7721
257+
258+
with pm.Model():
259+
a = pm.Poisson("a", mu=10)
260+
b = pm.Binomial("b", n=a, p=0.8)
261+
c = pm.Poisson("c", mu=11)
262+
d = pm.Dirichlet("d", a=[c, b])
263+
264+
step = pm.CompoundStep(
265+
[
266+
pm.CompoundStep([pm.Metropolis(a), pm.Metropolis(b), pm.Metropolis(c)]),
267+
pm.NUTS([d]),
268+
]
269+
)
270+
271+
kwargs = {
272+
"draws": 10,
273+
"tune": 10,
274+
"chains": 2,
275+
"compute_convergence_checks": False,
276+
"step": step,
277+
"progressbar": True,
278+
}
279+
280+
pm.sample(**kwargs, cores=1)
281+
pm.sample(**kwargs, cores=2)

0 commit comments

Comments
 (0)