3636from arviz import InferenceData , dict_to_dataset
3737from arviz .data .base import make_attrs
3838from pytensor .graph .basic import Variable
39- from rich .console import Console
40- from rich .progress import BarColumn , TextColumn , TimeElapsedColumn , TimeRemainingColumn
4139from rich .theme import Theme
4240from threadpoolctl import threadpool_limits
4341from typing_extensions import Protocol
6765from pymc .step_methods .arraystep import BlockedStep , PopulationArrayStepShared
6866from pymc .step_methods .hmc import quadpotential
6967from pymc .util import (
70- CustomProgress ,
68+ ProgressBarManager ,
69+ ProgressBarType ,
7170 RandomSeed ,
7271 RandomState ,
7372 _get_seeds_per_chain ,
@@ -278,7 +277,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
278277 else :
279278 varnames = ", " .join (
280279 [
281- get_untransformed_name (v .name ) if is_transformed_name (v .name ) else v .name
280+ get_untransformed_name (v .name ) if is_transformed_name (v .name ) else v .name # type: ignore[misc]
282281 for v in s .vars
283282 ]
284283 )
@@ -425,7 +424,7 @@ def sample(
425424 chains : int | None = None ,
426425 cores : int | None = None ,
427426 random_seed : RandomState = None ,
428- progressbar : bool = True ,
427+ progressbar : bool | ProgressBarType = True ,
429428 progressbar_theme : Theme | None = default_progress_theme ,
430429 step = None ,
431430 var_names : Sequence [str ] | None = None ,
@@ -457,7 +456,7 @@ def sample(
457456 chains : int | None = None ,
458457 cores : int | None = None ,
459458 random_seed : RandomState = None ,
460- progressbar : bool = True ,
459+ progressbar : bool | ProgressBarType = True ,
461460 progressbar_theme : Theme | None = default_progress_theme ,
462461 step = None ,
463462 var_names : Sequence [str ] | None = None ,
@@ -489,8 +488,8 @@ def sample(
489488 chains : int | None = None ,
490489 cores : int | None = None ,
491490 random_seed : RandomState = None ,
492- progressbar : bool = True ,
493- progressbar_theme : Theme | None = default_progress_theme ,
491+ progressbar : bool | ProgressBarType = True ,
492+ progressbar_theme : Theme | None = None ,
494493 step = None ,
495494 var_names : Sequence [str ] | None = None ,
496495 nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -540,11 +539,18 @@ def sample(
540539 A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
541540 We no longer support ``RandomState`` objects because their seeding mechanism does not allow
542541 easy spawning of new independent random streams that are needed by the step methods.
543- progressbar : bool, optional default=True
544- Whether or not to display a progress bar in the command line. The bar shows the percentage
545- of completion, the sampling speed in samples per second (SPS), and the estimated remaining
546- time until completion ("expected time of arrival"; ETA).
547- Only applicable to the pymc nuts sampler.
542+ progressbar: bool or ProgressType, optional
543+ How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
544+ for one of the following:
545+ - "combined": A single progress bar that displays the total progress across all chains. Only timing
546+ information is shown.
547+ - "split": A separate progress bar for each chain. Only timing information is shown.
548+ - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
549+ chains. Aggregate sample statistics are also displayed.
550+ - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
551+ are also displayed.
552+
553+ If True, the default is "split+stats" is used.
548554 step : function or iterable of functions
549555 A step function or collection of functions. If there are variables without step methods,
550556 step methods for those variables will be assigned automatically. By default the NUTS step
@@ -710,6 +716,10 @@ def sample(
710716 if isinstance (trace , list ):
711717 raise ValueError ("Please use `var_names` keyword argument for partial traces." )
712718
719+ # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and
720+ # ADVI initialization expect just a bool.
721+ progress_bool = bool (progressbar )
722+
713723 model = modelcontext (model )
714724 if not model .free_RVs :
715725 raise SamplingError (
@@ -806,7 +816,7 @@ def joined_blas_limiter():
806816 initvals = initvals ,
807817 model = model ,
808818 var_names = var_names ,
809- progressbar = progressbar ,
819+ progressbar = progress_bool ,
810820 idata_kwargs = idata_kwargs ,
811821 compute_convergence_checks = compute_convergence_checks ,
812822 nuts_sampler_kwargs = nuts_sampler_kwargs ,
@@ -825,7 +835,7 @@ def joined_blas_limiter():
825835 n_init = n_init ,
826836 model = model ,
827837 random_seed = random_seed_list ,
828- progressbar = progressbar ,
838+ progressbar = progress_bool ,
829839 jitter_max_retries = jitter_max_retries ,
830840 tune = tune ,
831841 initvals = initvals ,
@@ -1139,34 +1149,44 @@ def _sample_many(
11391149 Step function
11401150 """
11411151 initial_step_state = step .sampling_state
1142- for i in range (chains ):
1143- step .sampling_state = initial_step_state
1144- _sample (
1145- draws = draws ,
1146- chain = i ,
1147- start = start [i ],
1148- step = step ,
1149- trace = traces [i ],
1150- rng = rngs [i ],
1151- callback = callback ,
1152- ** kwargs ,
1153- )
1152+ progress_manager = ProgressBarManager (
1153+ step_method = step ,
1154+ chains = chains ,
1155+ draws = draws - kwargs .get ("tune" , 0 ),
1156+ tune = kwargs .get ("tune" , 0 ),
1157+ progressbar = kwargs .get ("progressbar" , True ),
1158+ progressbar_theme = kwargs .get ("progressbar_theme" , default_progress_theme ),
1159+ )
1160+
1161+ with progress_manager :
1162+ for i in range (chains ):
1163+ step .sampling_state = initial_step_state
1164+ _sample (
1165+ draws = draws ,
1166+ chain = i ,
1167+ start = start [i ],
1168+ step = step ,
1169+ trace = traces [i ],
1170+ rng = rngs [i ],
1171+ callback = callback ,
1172+ progress_manager = progress_manager ,
1173+ ** kwargs ,
1174+ )
11541175 return
11551176
11561177
11571178def _sample (
11581179 * ,
11591180 chain : int ,
1160- progressbar : bool ,
11611181 rng : np .random .Generator ,
11621182 start : PointType ,
11631183 draws : int ,
11641184 step : Step ,
11651185 trace : IBaseTrace ,
11661186 tune : int ,
11671187 model : Model | None = None ,
1168- progressbar_theme : Theme | None = default_progress_theme ,
11691188 callback = None ,
1189+ progress_manager : ProgressBarManager ,
11701190 ** kwargs ,
11711191) -> None :
11721192 """Sample one chain (singleprocess).
@@ -1177,27 +1197,23 @@ def _sample(
11771197 ----------
11781198 chain : int
11791199 Number of the chain that the samples will belong to.
1180- progressbar : bool
1181- Whether or not to display a progress bar in the command line. The bar shows the percentage
1182- of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1183- time until completion ("expected time of arrival"; ETA).
1184- random_seed : single random seed
1200+ random_seed : Generator
1201+ Single random seed
11851202 start : dict
11861203 Starting point in parameter space (or partial point)
11871204 draws : int
11881205 The number of samples to draw
1189- step : function
1190- Step function
1206+ step : Step
1207+ Step class instance used to generate samples.
11911208 trace
11921209 A chain backend to record draws and stats.
11931210 tune : int
11941211 Number of iterations to tune.
1195- model : Model (optional if in ``with`` context)
1196- progressbar_theme : Theme
1197- Optional custom theme for the progress bar.
1212+ model : Model, optional
1213+ PyMC model. If None, the model is taken from the current context.
1214+ progress_manager: ProgressBarManager
1215+ Helper class used to handle progress bar styling and updates
11981216 """
1199- skip_first = kwargs .get ("skip_first" , 0 )
1200-
12011217 sampling_gen = _iter_sample (
12021218 draws = draws ,
12031219 step = step ,
@@ -1209,32 +1225,19 @@ def _sample(
12091225 rng = rng ,
12101226 callback = callback ,
12111227 )
1212- _pbar_data = {"chain" : chain , "divergences" : 0 }
1213- _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1214-
1215- progress = CustomProgress (
1216- "[progress.description]{task.description}" ,
1217- BarColumn (),
1218- "[progress.percentage]{task.percentage:>3.0f}%" ,
1219- TimeRemainingColumn (),
1220- TextColumn ("/" ),
1221- TimeElapsedColumn (),
1222- console = Console (theme = progressbar_theme ),
1223- disable = not progressbar ,
1224- )
1228+ try :
1229+ for it , stats in enumerate (sampling_gen ):
1230+ progress_manager .update (
1231+ chain_idx = chain , is_last = False , draw = it , stats = stats , tuning = it > tune
1232+ )
12251233
1226- with progress :
1227- try :
1228- task = progress .add_task (_desc .format (** _pbar_data ), completed = 0 , total = draws )
1229- for it , diverging in enumerate (sampling_gen ):
1230- if it >= skip_first and diverging :
1231- _pbar_data ["divergences" ] += 1
1232- progress .update (task , description = _desc .format (** _pbar_data ), completed = it )
1233- progress .update (
1234- task , description = _desc .format (** _pbar_data ), completed = draws , refresh = True
1234+ if not progress_manager .combined_progress or chain == progress_manager .chains - 1 :
1235+ progress_manager .update (
1236+ chain_idx = chain , is_last = True , draw = it , stats = stats , tuning = False
12351237 )
1236- except KeyboardInterrupt :
1237- pass
1238+
1239+ except KeyboardInterrupt :
1240+ pass
12381241
12391242
12401243def _iter_sample (
@@ -1248,7 +1251,7 @@ def _iter_sample(
12481251 rng : np .random .Generator ,
12491252 model : Model | None = None ,
12501253 callback : SamplingIteratorCallback | None = None ,
1251- ) -> Iterator [bool ]:
1254+ ) -> Iterator [list [ dict [ str , Any ]] ]:
12521255 """Sample one chain with a generator (singleprocess).
12531256
12541257 Parameters
@@ -1271,8 +1274,8 @@ def _iter_sample(
12711274
12721275 Yields
12731276 ------
1274- diverging : bool
1275- Indicates if the draw is divergent. Only available with some samplers.
1277+ stats : list of dict
1278+ Dictionary of statistics returned by step sampler
12761279 """
12771280 draws = int (draws )
12781281
@@ -1294,22 +1297,25 @@ def _iter_sample(
12941297 step .iter_count = 0
12951298 if i == tune :
12961299 step .stop_tuning ()
1300+
12971301 point , stats = step .step (point )
12981302 trace .record (point , stats )
12991303 log_warning_stats (stats )
1300- diverging = i > tune and len ( stats ) > 0 and ( stats [ 0 ]. get ( "diverging" ) is True )
1304+
13011305 if callback is not None :
13021306 callback (
13031307 trace = trace ,
13041308 draw = Draw (chain , i == draws , i , i < tune , stats , point ),
13051309 )
13061310
1307- yield diverging
1311+ yield stats
1312+
13081313 except (KeyboardInterrupt , BaseException ):
13091314 if isinstance (trace , ZarrChain ):
13101315 trace .record_sampling_state (step = step )
13111316 trace .close ()
13121317 raise
1318+
13131319 else :
13141320 if isinstance (trace , ZarrChain ):
13151321 trace .record_sampling_state (step = step )
0 commit comments