Skip to content

Commit 28a80c1

Browse files
Use ProgressManager in _sample_many
1 parent 1e13cf9 commit 28a80c1

File tree

1 file changed

+44
-53
lines changed

1 file changed

+44
-53
lines changed

pymc/sampling/mcmc.py

Lines changed: 44 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
from arviz import InferenceData, dict_to_dataset
3737
from arviz.data.base import make_attrs
3838
from pytensor.graph.basic import Variable
39-
from rich.console import Console
40-
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4139
from rich.theme import Theme
4240
from threadpoolctl import threadpool_limits
4341
from typing_extensions import Protocol
@@ -67,7 +65,7 @@
6765
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
6866
from pymc.step_methods.hmc import quadpotential
6967
from pymc.util import (
70-
CustomProgress,
68+
ProgressManager,
7169
RandomSeed,
7270
RandomState,
7371
_get_seeds_per_chain,
@@ -1138,34 +1136,44 @@ def _sample_many(
11381136
Step function
11391137
"""
11401138
initial_step_state = step.sampling_state
1141-
for i in range(chains):
1142-
step.sampling_state = initial_step_state
1143-
_sample(
1144-
draws=draws,
1145-
chain=i,
1146-
start=start[i],
1147-
step=step,
1148-
trace=traces[i],
1149-
rng=rngs[i],
1150-
callback=callback,
1151-
**kwargs,
1152-
)
1139+
progress_manager = ProgressManager(
1140+
step_method=step,
1141+
chains=chains,
1142+
draws=draws - kwargs.get("tune", 0),
1143+
tune=kwargs.get("tune", 0),
1144+
progressbar=kwargs.get("progressbar", True),
1145+
progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme),
1146+
)
1147+
1148+
with progress_manager:
1149+
for i in range(chains):
1150+
step.sampling_state = initial_step_state
1151+
_sample(
1152+
draws=draws,
1153+
chain=i,
1154+
start=start[i],
1155+
step=step,
1156+
trace=traces[i],
1157+
rng=rngs[i],
1158+
callback=callback,
1159+
progress_manager=progress_manager,
1160+
**kwargs,
1161+
)
11531162
return
11541163

11551164

11561165
def _sample(
11571166
*,
11581167
chain: int,
1159-
progressbar: bool,
11601168
rng: np.random.Generator,
11611169
start: PointType,
11621170
draws: int,
11631171
step: Step,
11641172
trace: IBaseTrace,
11651173
tune: int,
11661174
model: Model | None = None,
1167-
progressbar_theme: Theme | None = default_progress_theme,
11681175
callback=None,
1176+
progress_manager: ProgressManager,
11691177
**kwargs,
11701178
) -> None:
11711179
"""Sample one chain (singleprocess).
@@ -1176,27 +1184,23 @@ def _sample(
11761184
----------
11771185
chain : int
11781186
Number of the chain that the samples will belong to.
1179-
progressbar : bool
1180-
Whether or not to display a progress bar in the command line. The bar shows the percentage
1181-
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1182-
time until completion ("expected time of arrival"; ETA).
1183-
random_seed : single random seed
1187+
random_seed : Generator
1188+
Single random seed
11841189
start : dict
11851190
Starting point in parameter space (or partial point)
11861191
draws : int
11871192
The number of samples to draw
1188-
step : function
1189-
Step function
1193+
step : Step
1194+
Step class instance used to generate samples.
11901195
trace
11911196
A chain backend to record draws and stats.
11921197
tune : int
11931198
Number of iterations to tune.
1194-
model : Model (optional if in ``with`` context)
1195-
progressbar_theme : Theme
1196-
Optional custom theme for the progress bar.
1199+
model : Model, optional
1200+
PyMC model. If None, the model is taken from the current context.
1201+
progress_manager: ProgressManager
1202+
Helper class used to handle progress bar styling and updates
11971203
"""
1198-
skip_first = kwargs.get("skip_first", 0)
1199-
12001204
sampling_gen = _iter_sample(
12011205
draws=draws,
12021206
step=step,
@@ -1208,32 +1212,19 @@ def _sample(
12081212
rng=rng,
12091213
callback=callback,
12101214
)
1211-
_pbar_data = {"chain": chain, "divergences": 0}
1212-
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1213-
1214-
progress = CustomProgress(
1215-
"[progress.description]{task.description}",
1216-
BarColumn(),
1217-
"[progress.percentage]{task.percentage:>3.0f}%",
1218-
TimeRemainingColumn(),
1219-
TextColumn("/"),
1220-
TimeElapsedColumn(),
1221-
console=Console(theme=progressbar_theme),
1222-
disable=not progressbar,
1223-
)
1215+
try:
1216+
for it, stats in enumerate(sampling_gen):
1217+
progress_manager.update(
1218+
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
1219+
)
12241220

1225-
with progress:
1226-
try:
1227-
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
1228-
for it, diverging in enumerate(sampling_gen):
1229-
if it >= skip_first and diverging:
1230-
_pbar_data["divergences"] += 1
1231-
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
1232-
progress.update(
1233-
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
1221+
if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
1222+
progress_manager.update(
1223+
chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False
12341224
)
1235-
except KeyboardInterrupt:
1236-
pass
1225+
1226+
except KeyboardInterrupt:
1227+
pass
12371228

12381229

12391230
def _iter_sample(

0 commit comments

Comments
 (0)