1919from collections import namedtuple
2020from collections .abc import Iterable , Sequence
2121from copy import deepcopy
22- from typing import NewType , cast
22+ from typing import TYPE_CHECKING , Literal , NewType , cast
2323
2424import arviz
2525import cloudpickle
4646
4747from pymc .exceptions import BlockModelAccessError
4848
49+ if TYPE_CHECKING :
50+ from pymc import BlockedStep
51+
52+
53+ ProgressType = Literal [
54+ "chain" ,
55+ "combined" ,
56+ "simple" ,
57+ "full" ,
58+ "combined+full" ,
59+ "full+combined" ,
60+ "combined+simple" ,
61+ "simple+combined" ,
62+ "chain+full" ,
63+ "full+chain" ,
64+ ]
65+
4966
5067def __getattr__ (name ):
5168 if name == "dataset_to_point_list" :
@@ -639,6 +656,7 @@ def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
639656 """
640657
641658 def call_column (column , task ):
659+ # Subclass rich.BarColumn and add a callback method to dynamically update the display
642660 if hasattr (column , "callbacks" ):
643661 column .callbacks (task )
644662
@@ -681,6 +699,8 @@ def call_column(column, task):
681699
682700
683701class DivergenceBarColumn (BarColumn ):
702+ """Rich colorbar that changes color when a chain has detected a divergence."""
703+
684704 def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
685705 from matplotlib .colors import to_rgb
686706
@@ -703,7 +723,53 @@ def callbacks(self, task: "Task"):
703723
704724
705725class ProgressManager :
706- def __init__ (self , step_method , chains , draws , tune , progressbar , progressbar_theme ):
726+ """Manage progress bars displayed during sampling."""
727+
728+ def __init__ (
729+ self ,
730+ step_method : BlockedStep ,
731+ chains : int ,
732+ draws : int ,
733+ tune : int ,
734+ progressbar : bool | ProgressType = True ,
735+ progressbar_theme : Theme = default_progress_theme ,
736+ ):
737+ """
738+ Manage progress bars displayed during sampling.
739+
740+ When sampling, Step classes are responsible for computing and exposing statistics that can be reported on
741+ progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config`
742+ and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which
743+ columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics
744+ that will be displayed on the progress bar.
745+
746+ Parameters
747+ ----------
748+ step_method: BlockedStep
749+ The step method being used to sample
750+ chains: int
751+ Number of chains being sampled
752+ draws: int
753+ Number of draws per chain
754+ tune: int
755+ Number of tuning steps per chain
756+ progressbar: bool or ProgressType, optional
757+ How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
758+ for either:
759+ - "combined": A single progress bar that displays the progress of all chains combined.
760+ - "chain": A separate progress bar for each chain.
761+
762+ You can also combine the above options with:
763+ - "simple": A simple progress bar that displays only timing information alongside the progress bar.
764+ - "full": A progress bar that displays all available statistics.
765+
766+ These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple".
767+
768+ If True, the default is "chain+full".
769+
770+ progressbar_theme: Theme, optional
771+ The theme to use for the progress bar. Defaults to the default theme.
772+ """
707773 self .combined_progress = False
708774 self .full_stats = True
709775 show_progress = True
0 commit comments