3636from arviz import InferenceData , dict_to_dataset
3737from arviz .data .base import make_attrs
3838from pytensor .graph .basic import Variable
39- from rich .console import Console
40- from rich .progress import BarColumn , TextColumn , TimeElapsedColumn , TimeRemainingColumn
4139from rich .theme import Theme
4240from threadpoolctl import threadpool_limits
4341from typing_extensions import Protocol
6765from pymc .step_methods .arraystep import BlockedStep , PopulationArrayStepShared
6866from pymc .step_methods .hmc import quadpotential
6967from pymc .util import (
70- CustomProgress ,
68+ ProgressManager ,
7169 RandomSeed ,
7270 RandomState ,
7371 _get_seeds_per_chain ,
@@ -1138,34 +1136,44 @@ def _sample_many(
11381136 Step function
11391137 """
11401138 initial_step_state = step .sampling_state
1141- for i in range (chains ):
1142- step .sampling_state = initial_step_state
1143- _sample (
1144- draws = draws ,
1145- chain = i ,
1146- start = start [i ],
1147- step = step ,
1148- trace = traces [i ],
1149- rng = rngs [i ],
1150- callback = callback ,
1151- ** kwargs ,
1152- )
1139+ progress_manager = ProgressManager (
1140+ step_method = step ,
1141+ chains = chains ,
1142+ draws = draws - kwargs .get ("tune" , 0 ),
1143+ tune = kwargs .get ("tune" , 0 ),
1144+ progressbar = kwargs .get ("progressbar" , True ),
1145+ progressbar_theme = kwargs .get ("progressbar_theme" , default_progress_theme ),
1146+ )
1147+
1148+ with progress_manager :
1149+ for i in range (chains ):
1150+ step .sampling_state = initial_step_state
1151+ _sample (
1152+ draws = draws ,
1153+ chain = i ,
1154+ start = start [i ],
1155+ step = step ,
1156+ trace = traces [i ],
1157+ rng = rngs [i ],
1158+ callback = callback ,
1159+ progress_manager = progress_manager ,
1160+ ** kwargs ,
1161+ )
11531162 return
11541163
11551164
11561165def _sample (
11571166 * ,
11581167 chain : int ,
1159- progressbar : bool ,
11601168 rng : np .random .Generator ,
11611169 start : PointType ,
11621170 draws : int ,
11631171 step : Step ,
11641172 trace : IBaseTrace ,
11651173 tune : int ,
11661174 model : Model | None = None ,
1167- progressbar_theme : Theme | None = default_progress_theme ,
11681175 callback = None ,
1176+ progress_manager : ProgressManager ,
11691177 ** kwargs ,
11701178) -> None :
11711179 """Sample one chain (singleprocess).
@@ -1176,27 +1184,23 @@ def _sample(
11761184 ----------
11771185 chain : int
11781186 Number of the chain that the samples will belong to.
1179- progressbar : bool
1180- Whether or not to display a progress bar in the command line. The bar shows the percentage
1181- of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1182- time until completion ("expected time of arrival"; ETA).
1183- random_seed : single random seed
1187+ random_seed : Generator
1188+ Single random seed
11841189 start : dict
11851190 Starting point in parameter space (or partial point)
11861191 draws : int
11871192 The number of samples to draw
1188- step : function
1189- Step function
1193+ step : Step
1194+ Step class instance used to generate samples.
11901195 trace
11911196 A chain backend to record draws and stats.
11921197 tune : int
11931198 Number of iterations to tune.
1194- model : Model (optional if in ``with`` context)
1195- progressbar_theme : Theme
1196- Optional custom theme for the progress bar.
1199+ model : Model, optional
1200+ PyMC model. If None, the model is taken from the current context.
1201+ progress_manager: ProgressManager
1202+ Helper class used to handle progress bar styling and updates
11971203 """
1198- skip_first = kwargs .get ("skip_first" , 0 )
1199-
12001204 sampling_gen = _iter_sample (
12011205 draws = draws ,
12021206 step = step ,
@@ -1208,32 +1212,19 @@ def _sample(
12081212 rng = rng ,
12091213 callback = callback ,
12101214 )
1211- _pbar_data = {"chain" : chain , "divergences" : 0 }
1212- _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1213-
1214- progress = CustomProgress (
1215- "[progress.description]{task.description}" ,
1216- BarColumn (),
1217- "[progress.percentage]{task.percentage:>3.0f}%" ,
1218- TimeRemainingColumn (),
1219- TextColumn ("/" ),
1220- TimeElapsedColumn (),
1221- console = Console (theme = progressbar_theme ),
1222- disable = not progressbar ,
1223- )
1215+ try :
1216+ for it , stats in enumerate (sampling_gen ):
1217+ progress_manager .update (
1218+ chain_idx = chain , is_last = False , draw = it , stats = stats , tuning = it > tune
1219+ )
12241220
1225- with progress :
1226- try :
1227- task = progress .add_task (_desc .format (** _pbar_data ), completed = 0 , total = draws )
1228- for it , diverging in enumerate (sampling_gen ):
1229- if it >= skip_first and diverging :
1230- _pbar_data ["divergences" ] += 1
1231- progress .update (task , description = _desc .format (** _pbar_data ), completed = it )
1232- progress .update (
1233- task , description = _desc .format (** _pbar_data ), completed = draws , refresh = True
1221+ if not progress_manager .combined_progress or chain == progress_manager .chains - 1 :
1222+ progress_manager .update (
1223+ chain_idx = chain , is_last = True , draw = it , stats = stats , tuning = False
12341224 )
1235- except KeyboardInterrupt :
1236- pass
1225+
1226+ except KeyboardInterrupt :
1227+ pass
12371228
12381229
12391230def _iter_sample (
0 commit comments