Skip to content

Commit a8af2e8

Browse files
Document how step methods provide progress bar stats
1 parent af81955 commit a8af2e8

File tree

4 files changed

+296
-41
lines changed

4 files changed

+296
-41
lines changed

pymc/step_methods/compound.py

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@
2121
import warnings
2222

2323
from abc import ABC, abstractmethod
24-
from collections.abc import Iterable, Mapping, Sequence
24+
from collections.abc import Callable, Iterable, Mapping, Sequence
2525
from dataclasses import field
2626
from enum import IntEnum, unique
2727
from typing import Any
2828

2929
import numpy as np
3030

3131
from pytensor.graph.basic import Variable
32+
from rich.progress import ProgressColumn
3233

3334
from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType
3435
from pymc.model import modelcontext
@@ -181,17 +182,72 @@ def __new__(cls, *args, **kwargs):
181182
step.__newargs = (vars, *args), kwargs
182183
return step
183184

184-
@staticmethod
185-
def _progressbar_config(n_chains=1):
185+
def _progressbar_config(self, n_chains: int = 1):
186+
"""
187+
Get progressbar configuration for this step sampler.
188+
189+
By default, the progress bar displays no stats columns, only basic info (number of draws and sampling time).
190+
Specific step methods should override this method to specify which stats to display and how.
191+
192+
Parameters
193+
----------
194+
n_chains: int
195+
Number of chains being sampled. This controls the number of progress bars that will be displayed.
196+
197+
Returns
198+
-------
199+
columns: list of rich.progress.ProgressColumn
200+
List of columns to display in the progress bar.
201+
202+
stats: dict
203+
Dictionary of statistics associated with each column.
204+
"""
186205
columns = []
187206
stats = {}
188207

189208
return columns, stats
190209

191-
@staticmethod
192-
def _make_update_stats_function():
193-
def update_stats(stats, step_stats, chain_idx):
194-
return stats
210+
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
211+
"""
212+
Create an update function used by the progress bar to update statistics during sampling.
213+
214+
By default, the update is a no-op. Specific step methods should implement special logic for which
215+
statistics to display and how.
216+
217+
Returns
218+
-------
219+
update_stats: Callable
220+
Function that updates displayed statistics for the current chain, given statistics generated by the step
221+
during the most recent step.
222+
"""
223+
224+
def update_stats(
225+
displayed_stats: dict[str, np.ndarray],
226+
step_stats: dict[str, str | float | int | bool | None],
227+
chain_idx: int,
228+
) -> dict[str, np.ndarray]:
229+
"""
230+
Update the statistics displayed in the progress bar after each step.
231+
232+
Parameters
233+
----------
234+
displayed_stats: dict
235+
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
236+
the values are the current values of the statistics, with one value per chain being sampled.
237+
238+
step_stats: dict
239+
Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
240+
names of the statistics and the values are the values of the statistics generated by the step sampler.
241+
242+
chain_idx: int
243+
The chain number associated with the current step.
244+
245+
Returns
246+
-------
247+
dict
248+
The updated statistics dictionary to be displayed in the progress bar.
249+
"""
250+
return displayed_stats
195251

196252
return update_stats
197253

@@ -311,7 +367,28 @@ def set_rng(self, rng: RandomGenerator):
311367
for method, _rng in zip(self.methods, _rngs):
312368
method.set_rng(_rng)
313369

314-
def _progressbar_config(self, n_chains=1):
370+
def _progressbar_config(
371+
self, n_chains: int = 1
372+
) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]:
373+
"""
374+
Get progressbar configuration for this step sampler.
375+
376+
The columns of the rich progress bar displayed during sampling are chosen by the step samplers themselves. In
377+
the compound step case, we display the set union of all columns from the sub-step samplers.
378+
379+
Parameters
380+
----------
381+
n_chains: int
382+
Number of chains being sampled. This controls the number of progress bars that will be displayed.
383+
384+
Returns
385+
-------
386+
columns: list of rich.progress.ProgressColumn
387+
List of columns to display in the progress bar.
388+
389+
stats: dict
390+
Dictionary of statistics associated with each column.
391+
"""
315392
from functools import reduce
316393

317394
column_lists, stat_dict_list = zip(
@@ -332,14 +409,56 @@ def _progressbar_config(self, n_chains=1):
332409

333410
return columns, stats
334411

335-
def _make_update_stats_function(self):
412+
def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict]:
413+
"""
414+
Create an update function used by the progress bar to update statistics during sampling.
415+
416+
Returns
417+
-------
418+
update_stats: Callable
419+
Function that updates displayed statistics for the current chain, given statistics generated by the step
420+
during the most recent step.
421+
"""
336422
update_fns = [method._make_update_stats_function() for method in self.methods]
337423

338-
def update_stats(stats, step_stats, chain_idx):
424+
def update_stats(
425+
displayed_stats: dict[str, np.ndarray],
426+
step_stats: list[dict[str, str | float | int | bool | None]],
427+
chain_idx: int,
428+
) -> dict[str, np.ndarray]:
429+
"""
430+
Update the statistics displayed in the progress bar after each step.
431+
432+
Parameters
433+
----------
434+
displayed_stats: dict
435+
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
436+
the values are the current values of the statistics, with one value per chain being sampled.
437+
438+
step_stats: list of dict
439+
List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when
440+
taking the current step. For each dictionary, the keys are names of statistics and the values are
441+
the values of the statistics generated by the step sampler.
442+
443+
chain_idx: int
444+
The chain number associated with the current step.
445+
446+
Returns
447+
-------
448+
dict
449+
The updated statistics dictionary to be displayed in the progress bar.
450+
"""
451+
# TODO: The compound step is commonly made of many instances of the same step (e.g. 3 Metropolis steps).
452+
# In this case, the current loop logic is just overriding each Metropolis step's stats with those of the
453+
# next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate
454+
# the stats from each step.
455+
if not isinstance(step_stats, list):
456+
step_stats = [step_stats]
457+
339458
for step_stat, update_fn in zip(step_stats, update_fns):
340-
stats = update_fn(stats, step_stat, chain_idx)
459+
displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
341460

342-
return stats
461+
return displayed_stats
343462

344463
return update_stats
345464

pymc/step_methods/hmc/nuts.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from __future__ import annotations
1616

1717
from collections import namedtuple
18+
from collections.abc import Callable
1819
from dataclasses import field
1920

2021
import numpy as np
2122

2223
from pytensor import config
23-
from rich.progress import TextColumn
24+
from rich.progress import ProgressColumn, TextColumn
2425
from rich.table import Column
2526

2627
from pymc.stats.convergence import SamplerWarning
@@ -231,8 +232,25 @@ def competence(var, has_grad):
231232
return Competence.PREFERRED
232233
return Competence.INCOMPATIBLE
233234

234-
@staticmethod
235-
def _progressbar_config(n_chains=1):
235+
def _progressbar_config(
236+
self, n_chains: int = 1
237+
) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]:
238+
"""
239+
Get progressbar configuration for this step sampler.
240+
241+
Parameters
242+
----------
243+
n_chains: int
244+
Number of chains being sampled. This controls the number of progress bars that will be displayed.
245+
246+
Returns
247+
-------
248+
columns: list of rich.progress.ProgressColumn
249+
List of columns to display in the progress bar.
250+
251+
stats: dict
252+
Dictionary of statistics associated with each column.
253+
"""
236254
columns = [
237255
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
238256
TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)),
@@ -247,18 +265,52 @@ def _progressbar_config(n_chains=1):
247265

248266
return columns, stats
249267

250-
@staticmethod
251-
def _make_update_stats_function():
252-
def update_stats(stats, step_stats, chain_idx):
268+
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
269+
"""
270+
Create an update function used by the progress bar to update statistics during sampling.
271+
272+
Returns
273+
-------
274+
update_stats: Callable
275+
Function that updates displayed statistics for the current chain, given statistics generated by the step
276+
during the most recent step.
277+
"""
278+
279+
def update_stats(
280+
displayed_stats: dict[str, np.ndarray],
281+
step_stats: dict[str, str | float | int | bool | None],
282+
chain_idx: int,
283+
) -> dict[str, np.ndarray]:
284+
"""
285+
Update the statistics displayed in the progress bar after each step.
286+
287+
Parameters
288+
----------
289+
displayed_stats: dict
290+
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
291+
the values are the current values of the statistics, with one value per chain being sampled.
292+
293+
step_stats: dict
294+
Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
295+
names of the statistics and the values are the values of the statistics generated by the step sampler.
296+
297+
chain_idx: int
298+
The chain number associated with the current step.
299+
300+
Returns
301+
-------
302+
dict
303+
The updated statistics dictionary to be displayed in the progress bar.
304+
"""
253305
if isinstance(step_stats, list):
254306
step_stats = step_stats[0]
255307

256308
if not step_stats["tune"]:
257-
stats["divergences"][chain_idx] += step_stats["diverging"]
309+
displayed_stats["divergences"][chain_idx] += step_stats["diverging"]
258310

259-
stats["step_size"][chain_idx] = step_stats["step_size"]
260-
stats["tree_size"][chain_idx] = step_stats["tree_size"]
261-
return stats
311+
displayed_stats["step_size"][chain_idx] = step_stats["step_size"]
312+
displayed_stats["tree_size"][chain_idx] = step_stats["tree_size"]
313+
return displayed_stats
262314

263315
return update_stats
264316

pymc/step_methods/metropolis.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pytensor import tensor as pt
2525
from pytensor.graph.fg import MissingInputError
2626
from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV
27-
from rich.progress import TextColumn
27+
from rich.progress import ProgressColumn, TextColumn
2828
from rich.table import Column
2929

3030
import pymc as pm
@@ -327,8 +327,25 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
327327
def competence(var, has_grad):
328328
return Competence.COMPATIBLE
329329

330-
@staticmethod
331-
def _progressbar_config(n_chains=1):
330+
def _progressbar_config(
331+
self, n_chains: int = 1
332+
) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]:
333+
"""
334+
Get progressbar configuration for this step sampler.
335+
336+
Parameters
337+
----------
338+
n_chains: int
339+
Number of chains being sampled. This controls the number of progress bars that will be displayed.
340+
341+
Returns
342+
-------
343+
columns: list of rich.progress.ProgressColumn
344+
List of columns to display in the progress bar.
345+
346+
stats: dict
347+
Dictionary of statistics associated with each column.
348+
"""
332349
columns = [
333350
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
334351
TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
@@ -345,17 +362,51 @@ def _progressbar_config(n_chains=1):
345362

346363
return columns, stats
347364

348-
@staticmethod
349-
def _make_update_stats_function():
350-
def update_stats(stats, step_stats, chain_idx):
365+
def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
366+
"""
367+
Create an update function used by the progress bar to update statistics during sampling.
368+
369+
Returns
370+
-------
371+
update_stats: Callable
372+
Function that updates displayed statistics for the current chain, given statistics generated by the step
373+
during the most recent step.
374+
"""
375+
376+
def update_stats(
377+
displayed_stats: dict[str, np.ndarray],
378+
step_stats: dict[str, str | float | int | bool | None],
379+
chain_idx: int,
380+
) -> dict[str, np.ndarray]:
381+
"""
382+
Update the statistics displayed in the progress bar after each step.
383+
384+
Parameters
385+
----------
386+
displayed_stats: dict
387+
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
388+
the values are the current values of the statistics, with one value per chain being sampled.
389+
390+
step_stats: dict
391+
Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
392+
names of the statistics and the values are the values of the statistics generated by the step sampler.
393+
394+
chain_idx: int
395+
The chain number associated with the current step.
396+
397+
Returns
398+
-------
399+
dict
400+
The updated statistics dictionary to be displayed in the progress bar.
401+
"""
351402
if isinstance(step_stats, list):
352403
step_stats = step_stats[0]
353404

354-
stats["tune"][chain_idx] = step_stats["tune"]
355-
stats["accept_rate"][chain_idx] = step_stats["accept"]
356-
stats["scaling"][chain_idx] = step_stats["scaling"]
405+
displayed_stats["tune"][chain_idx] = step_stats["tune"]
406+
displayed_stats["accept_rate"][chain_idx] = step_stats["accept"]
407+
displayed_stats["scaling"][chain_idx] = step_stats["scaling"]
357408

358-
return stats
409+
return displayed_stats
359410

360411
return update_stats
361412

0 commit comments

Comments
 (0)