2121import warnings
2222
2323from abc import ABC , abstractmethod
24- from collections .abc import Iterable , Mapping , Sequence
24+ from collections .abc import Callable , Iterable , Mapping , Sequence
2525from dataclasses import field
2626from enum import IntEnum , unique
2727from typing import Any
2828
2929import numpy as np
3030
3131from pytensor .graph .basic import Variable
32+ from rich .progress import ProgressColumn
3233
3334from pymc .blocking import PointType , StatDtype , StatsDict , StatShape , StatsType
3435from pymc .model import modelcontext
@@ -181,17 +182,72 @@ def __new__(cls, *args, **kwargs):
181182 step .__newargs = (vars , * args ), kwargs
182183 return step
183184
184- @staticmethod
185- def _progressbar_config (n_chains = 1 ):
185+ def _progressbar_config (self , n_chains : int = 1 ):
186+ """
187+ Get progressbar configuration for this step sampler.
188+
189+ By default, the progress bar displays no stats columns, only basic info (number of draws and sampling time).
190+ Specific step methods should overload this method to specify which stats to display and how.
191+
192+ Parameters
193+ ----------
194+ n_chains: int
195+ Number of chains being sampled. This controls the number of progress bars that will be displayed.
196+
197+ Returns
198+ -------
199+ columns: list of rich.progress.ProgressColumn
200+ List of columns to display in the progress bar.
201+
202+ stats: dict
203+ Dictionary of statistics associated with each column.
204+ """
186205 columns = []
187206 stats = {}
188207
189208 return columns , stats
190209
191- @staticmethod
192- def _make_update_stats_function ():
193- def update_stats (stats , step_stats , chain_idx ):
194- return stats
210+ def _make_update_stats_function (self ) -> Callable [[dict , dict , int ], dict ]:
211+ """
212+ Create an update function used by the progress bar to update statistics during sampling.
213+
214+ By default, the update is a no-op. Specific step methods should implement special logic for which
215+ statistics to display and how.
216+
217+ Returns
218+ -------
219+ update_stats: Callable
220+ Function that updates displayed statistics for the current chain, given statistics generated by the step
221+ during the most recent step.
222+ """
223+
224+ def update_stats (
225+ displayed_stats : dict [str , np .ndarray ],
226+ step_stats : dict [str , str | float | int | bool | None ],
227+ chain_idx : int ,
228+ ) -> dict [str , np .ndarray ]:
229+ """
230+ Update the statistics displayed in the progress bar after each step.
231+
232+ Parameters
233+ ----------
234+ displayed_stats: dict
235+ Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
236+ the values are the current values of the statistics, with one value per chain being sampled.
237+
238+ step_stats: dict
239+ Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
240+ names of the statistics and the values are the values of the statistics generated by the step sampler.
241+
242+ chain_idx: int
243+ The chain number associated with the current step
244+
245+ Returns
246+ -------
247+ dict
248+ The updated statistics dictionary to be displayed in the progress bar.
249+ """
250+ return displayed_stats
195251
196252 return update_stats
197253
@@ -311,7 +367,28 @@ def set_rng(self, rng: RandomGenerator):
311367 for method , _rng in zip (self .methods , _rngs ):
312368 method .set_rng (_rng )
313369
314- def _progressbar_config (self , n_chains = 1 ):
370+ def _progressbar_config (
371+ self , n_chains : int = 1
372+ ) -> tuple [list [ProgressColumn ], dict [str , np .ndarray | float ]]:
373+ """
374+ Get progressbar configuration for this step sampler.
375+
376+ The columns of the rich progress bar displayed during sampler are chosen by the step samplers themselves. In
377+ the compound step case, we display the set union of all columns from the sub-step samplers.
378+
379+ Parameters
380+ ----------
381+ n_chains: int
382+ Number of chains being sampled. This controls the number of progress bars that will be displayed.
383+
384+ Returns
385+ -------
386+ columns: list of rich.progress.ProgressColumn
387+ List of columns to display in the progress bar.
388+
389+ stats: dict
390+ Dictionary of statistics associated with each column.
391+ """
315392 from functools import reduce
316393
317394 column_lists , stat_dict_list = zip (
@@ -332,14 +409,56 @@ def _progressbar_config(self, n_chains=1):
332409
333410 return columns , stats
334411
335- def _make_update_stats_function (self ):
412+ def _make_update_stats_function (self ) -> Callable [[dict , list [dict ], int ], dict ]:
413+ """
414+ Create an update function used by the progress bar to update statistics during sampling.
415+
416+ Returns
417+ -------
418+ update_stats: Callable
419+ Function that updates displayed statistics for the current chain, given statistics generated by the step
420+ during the most recent step.
421+ """
336422 update_fns = [method ._make_update_stats_function () for method in self .methods ]
337423
338- def update_stats (stats , step_stats , chain_idx ):
424+ def update_stats (
425+ displayed_stats : dict [str , np .ndarray ],
426+ step_stats : list [dict [str , str | float | int | bool | None ]],
427+ chain_idx : int ,
428+ ) -> dict [str , np .ndarray ]:
429+ """
430+ Update the statistics displayed in the progress bar after each step.
431+
432+ Parameters
433+ ----------
434+ displayed_stats: dict
435+ Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
436+ the values are the current values of the statistics, with one value per chain being sampled.
437+
438+ step_stats: list of dict
439+ List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when
440+ taking the current step. For each dictionary, the keys are names of statistics and the values are
441+ the values of the statistics generated by the step sampler.
442+
443+ chain_idx: int
444+ The chain number associated with the current step
445+
446+ Returns
447+ -------
448+ dict
449+ The updated statistics dictionary to be displayed in the progress bar.
450+ """
451+ # TODO: The compound step is commonly made of many instances of the same step (e.g. 3 Metropolis steps).
452+ # In this case, the current loop logic is just overriding each Metropolis steps' stats with those of the
453+ # next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate
454+ # the stats from each step.
455+ if not isinstance (step_stats , list ):
456+ step_stats = [step_stats ]
457+
339458 for step_stat , update_fn in zip (step_stats , update_fns ):
340- stats = update_fn (stats , step_stat , chain_idx )
459+ displayed_stats = update_fn (displayed_stats , step_stat , chain_idx )
341460
342- return stats
461+ return displayed_stats
343462
344463 return update_stats
345464
0 commit comments