Skip to content

Commit 23d122f

Browse files
Step samplers are responsible for setting up progress bars
1 parent 06572c6 commit 23d122f

File tree

6 files changed

+188
-58
lines changed

6 files changed

+188
-58
lines changed

pymc/sampling/parallel.py

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,20 @@
2727
import cloudpickle
2828
import numpy as np
2929

30-
from rich.console import Console
31-
from rich.progress import TextColumn
32-
from rich.style import Style
33-
from rich.table import Column
3430
from rich.theme import Theme
3531
from threadpoolctl import threadpool_limits
3632

3733
from pymc.backends.zarr import ZarrChain
3834
from pymc.blocking import DictToArrayBijection
3935
from pymc.exceptions import SamplingError
4036
from pymc.util import (
41-
CustomProgress,
42-
DivergenceBarColumn,
43-
RandomGeneratorState,
37+
compute_draw_speed,
38+
create_progress_bar,
4439
default_progress_theme,
40+
RandomGeneratorState,
4541
get_state_from_generator,
4642
random_generator_from_state,
43+
4744
)
4845

4946
logger = logging.getLogger(__name__)
@@ -489,33 +486,21 @@ def __init__(
489486

490487
self._in_context = False
491488

492-
self._progress = CustomProgress(
493-
DivergenceBarColumn(
494-
table_column=Column("Progress", ratio=2),
495-
diverging_color="tab:red",
496-
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
497-
finished_style=Style.parse("rgb(44,160,44)"), # tab:green
498-
),
499-
TextColumn("{task.fields[draws]:,d}", table_column=Column("Draws", ratio=1)),
500-
TextColumn(
501-
"{task.fields[divergences]:,d}", table_column=Column("Divergences", ratio=1)
502-
),
503-
TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)),
504-
TextColumn("{task.fields[tree_depth]:,d}", table_column=Column("Tree depth", ratio=1)),
505-
TextColumn(
506-
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
507-
table_column=Column("Sampling Speed", ratio=1),
508-
),
509-
console=Console(theme=progressbar_theme),
510-
disable=not progressbar,
511-
include_headers=True,
489+
progress_columns, progress_stats = step_method._progressbar_config(chains)
490+
491+
self._progress = create_progress_bar(
492+
progress_columns,
493+
progress_stats,
494+
progressbar=progressbar,
495+
progressbar_theme=progressbar_theme,
512496
)
513497

498+
self.progress_stats = progress_stats
499+
self.update_stats = step_method._make_update_stat_function()
500+
514501
self._show_progress = progressbar
515502
self._divergences = 0
516-
self._divergences_by_chain = [0] * chains
517503
self._completed_draws = 0
518-
self._completed_draws_by_chain = [0] * chains
519504
self._total_draws = draws + tune
520505
self._desc = "Sampling chain"
521506
self._chains = chains
@@ -537,66 +522,48 @@ def __iter__(self):
537522
tasks = [
538523
progress.add_task(
539524
self._desc.format(self),
540-
completed=self._completed_draws,
525+
completed=0,
526+
draws=0,
541527
total=self._total_draws,
542528
chain_idx=chain_idx,
543-
draws=0,
544-
divergences=0,
545-
step_size=0.0,
546-
tree_depth=0,
547529
sampling_speed=0,
548530
speed_unit="draws/s",
531+
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
549532
)
550533
for chain_idx in range(self._chains)
551534
]
552535

553536
while self._active:
554537
draw = ProcessAdapter.recv_draw(self._active)
555538
proc, is_last, draw, tuning, stats = draw
556-
speed = 0
557-
unit = "draws/s"
558539

559540
self._completed_draws += 1
560-
self._completed_draws_by_chain[proc.chain] += 1
541+
542+
speed, unit = compute_draw_speed(progress._tasks[proc.chain].elapsed, draw)
561543

562544
if not tuning and stats and stats[0].get("diverging"):
563545
self._divergences += 1
564-
self._divergences_by_chain[proc.chain] += 1
565-
566-
if self._show_progress:
567-
elapsed = max(progress._tasks[proc.chain].elapsed, 1e-4)
568-
speed = self._completed_draws_by_chain[proc.chain] / elapsed
569546

570-
if speed > 1:
571-
unit = "draws/s"
572-
else:
573-
unit = "s/draws"
574-
speed = 1 / speed
547+
self.progress_stats = self.update_stats(self.progress_stats, stats, proc.chain)
575548

576549
progress.update(
577550
tasks[proc.chain],
578-
completed=self._completed_draws_by_chain[proc.chain],
551+
completed=draw,
579552
draws=draw,
580-
divergences=self._divergences_by_chain[proc.chain],
581-
step_size=stats[0].get("step_size", 0),
582-
tree_depth=stats[0].get("tree_size", 0),
583553
sampling_speed=speed,
584554
speed_unit=unit,
555+
**{stat: value[proc.chain] for stat, value in self.progress_stats.items()},
585556
)
586557

587558
if is_last:
588-
self._completed_draws_by_chain[proc.chain] += 1
589-
590559
proc.join()
591560
self._active.remove(proc)
592561
self._finished.append(proc)
593562
self._make_active()
594563
progress.update(
595564
tasks[proc.chain],
596565
draws=draw + 1,
597-
divergences=self._divergences_by_chain[proc.chain],
598-
step_size=stats[0].get("step_size", 0),
599-
tree_depth=stats[0].get("tree_size", 0),
566+
**{stat: value[proc.chain] for stat, value in self.progress_stats.items()},
600567
refresh=True,
601568
)
602569

pymc/step_methods/compound.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,38 @@ def set_rng(self, rng: RandomGenerator):
297297
for method, _rng in zip(self.methods, _rngs):
298298
method.set_rng(_rng)
299299

300+
def _progressbar_config(self, n_chains=1):
301+
from functools import reduce
302+
303+
column_lists, stat_dict_list = zip(
304+
*[method._progressbar_config(n_chains) for method in self.methods]
305+
)
306+
flat_list = reduce(lambda left_list, right_list: left_list + right_list, column_lists)
307+
308+
columns = []
309+
headers = []
310+
311+
for col in flat_list:
312+
name = col.get_table_column().header
313+
if name not in headers:
314+
headers.append(name)
315+
columns.append(col)
316+
317+
stats = reduce(lambda left_dict, right_dict: left_dict | right_dict, stat_dict_list)
318+
319+
return columns, stats
320+
321+
def _make_update_stat_function(self):
322+
update_fns = [method._make_update_stats_function() for method in self.methods]
323+
324+
def update_stats(stats, step_stats, chain_idx):
325+
for step_stat, update_fn in zip(step_stats, update_fns):
326+
stats = update_fn(stats, step_stat, chain_idx)
327+
328+
return stats
329+
330+
return update_stats
331+
300332

301333
def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
302334
"""Flatten a hierarchy of step methods to a list."""

pymc/step_methods/hmc/nuts.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import numpy as np
2121

2222
from pytensor import config
23+
from rich.progress import TextColumn
24+
from rich.table import Column
2325

2426
from pymc.stats.convergence import SamplerWarning
2527
from pymc.step_methods.compound import Competence
@@ -229,6 +231,35 @@ def competence(var, has_grad):
229231
return Competence.PREFERRED
230232
return Competence.INCOMPATIBLE
231233

234+
@staticmethod
235+
def _progressbar_config(n_chains=1):
236+
columns = [
237+
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
238+
TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)),
239+
TextColumn("{task.fields[tree_size]}", table_column=Column("Grad evals", ratio=1)),
240+
]
241+
242+
stats = {
243+
"divergences": [0] * n_chains,
244+
"step_size": [0] * n_chains,
245+
"tree_size": [0] * n_chains,
246+
}
247+
248+
return columns, stats
249+
250+
@staticmethod
251+
def _make_update_stat_function():
252+
def update_stats(stats, step_stats, chain_idx):
253+
if isinstance(step_stats, list):
254+
step_stats = step_stats[0]
255+
256+
stats["divergences"][chain_idx] += step_stats["diverging"]
257+
stats["step_size"][chain_idx] = step_stats["step_size"]
258+
stats["tree_size"][chain_idx] = step_stats["tree_size"]
259+
return stats
260+
261+
return update_stats
262+
232263

233264
# A proposal for the next position
234265
Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory")

pymc/step_methods/metropolis.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from pytensor import tensor as pt
2525
from pytensor.graph.fg import MissingInputError
2626
from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV
27+
from rich.progress import TextColumn
28+
from rich.table import Column
2729

2830
import pymc as pm
2931

@@ -325,6 +327,38 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
325327
def competence(var, has_grad):
326328
return Competence.COMPATIBLE
327329

330+
@staticmethod
331+
def _progressbar_config(n_chains=1):
332+
columns = [
333+
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
334+
TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
335+
TextColumn(
336+
"{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1)
337+
),
338+
]
339+
340+
stats = {
341+
"tune": [True] * n_chains,
342+
"scaling": [0] * n_chains,
343+
"accept_rate": [0.0] * n_chains,
344+
}
345+
346+
return columns, stats
347+
348+
@staticmethod
349+
def _make_update_stats_function():
350+
def update_stats(stats, step_stats, chain_idx):
351+
if isinstance(step_stats, list):
352+
step_stats = step_stats[0]
353+
354+
stats["tune"][chain_idx] = step_stats["tune"]
355+
stats["accept_rate"][chain_idx] = step_stats["accept"]
356+
stats["scaling"][chain_idx] = step_stats["scaling"]
357+
358+
return stats
359+
360+
return update_stats
361+
328362

329363
def tune(scale, acc_rate):
330364
"""

pymc/step_methods/slicer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
import numpy as np
1919

20+
from rich.progress import TextColumn
21+
from rich.table import Column
22+
2023
from pymc.blocking import RaveledVars, StatsType
2124
from pymc.initial_point import PointType
2225
from pymc.model import modelcontext
@@ -195,3 +198,29 @@ def competence(var, has_grad):
195198
return Competence.PREFERRED
196199
return Competence.COMPATIBLE
197200
return Competence.INCOMPATIBLE
201+
202+
@staticmethod
203+
def _progressbar_config(n_chains=1):
204+
columns = [
205+
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
206+
TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)),
207+
TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)),
208+
]
209+
210+
stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}
211+
212+
return columns, stats
213+
214+
@staticmethod
215+
def _make_update_stats_function():
216+
def update_stats(stats, step_stats, chain_idx):
217+
if isinstance(step_stats, list):
218+
step_stats = step_stats[0]
219+
220+
stats["tune"][chain_idx] = step_stats["tune"]
221+
stats["nstep_out"][chain_idx] = step_stats["nstep_out"]
222+
stats["nstep_in"][chain_idx] = step_stats["nstep_in"]
223+
224+
return stats
225+
226+
return update_stats

pymc/util.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from pytensor.compile import SharedVariable
3232
from pytensor.graph.utils import ValidatingScratchpad
3333
from rich.box import SIMPLE_HEAD
34-
from rich.progress import BarColumn, Progress, Task
34+
from rich.console import Console
35+
from rich.progress import BarColumn, Progress, Task, TextColumn
3536
from rich.style import Style
3637
from rich.table import Column, Table
3738
from rich.theme import Theme
@@ -684,14 +685,50 @@ def __init__(self, *args, diverging_color="red", **kwargs):
684685

685686
def callbacks(self, task: "Task"):
686687
divergences = task.fields.get("divergences", 0)
687-
if divergences > 0:
688+
if isinstance(divergences, float | int) and divergences > 0:
688689
self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
689690
self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
690691
else:
691692
self.complete_style = self.non_diverging_style
692693
self.finished_style = self.non_diverging_finished_style
693694

694695

696+
def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_theme):
697+
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
698+
columns += step_columns
699+
columns += [
700+
TextColumn(
701+
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
702+
table_column=Column("Sampling Speed", ratio=1),
703+
)
704+
]
705+
706+
return CustomProgress(
707+
DivergenceBarColumn(
708+
table_column=Column("Progress", ratio=2),
709+
diverging_color="tab:red",
710+
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
711+
finished_style=Style.parse("rgb(44,160,44)"), # tab:green
712+
),
713+
*columns,
714+
console=Console(theme=progressbar_theme),
715+
disable=not progressbar,
716+
include_headers=True,
717+
)
718+
719+
720+
def compute_draw_speed(elapsed, draws):
721+
speed = draws / max(elapsed, 1e-6)
722+
723+
if speed > 1 or speed == 0:
724+
unit = "draws/s"
725+
else:
726+
unit = "s/draws"
727+
speed = 1 / speed
728+
729+
return speed, unit
730+
731+
695732
RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"])
696733

697734

0 commit comments

Comments
 (0)