Skip to content

Commit f849206

Browse files
One progress bar per chain when samplings
1 parent fa43eba commit f849206

File tree

2 files changed

+168
-23
lines changed

2 files changed

+168
-23
lines changed

pymc/sampling/parallel.py

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
import numpy as np
2929

3030
from rich.console import Console
31-
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
31+
from rich.progress import TextColumn
32+
from rich.style import Style
33+
from rich.table import Column
3234
from rich.theme import Theme
3335
from threadpoolctl import threadpool_limits
3436

@@ -37,6 +39,7 @@
3739
from pymc.exceptions import SamplingError
3840
from pymc.util import (
3941
CustomProgress,
42+
DivergenceBarColumn,
4043
RandomGeneratorState,
4144
default_progress_theme,
4245
get_state_from_generator,
@@ -487,20 +490,35 @@ def __init__(
487490
self._in_context = False
488491

489492
self._progress = CustomProgress(
490-
"[progress.description]{task.description}",
491-
BarColumn(),
492-
"[progress.percentage]{task.percentage:>3.0f}%",
493-
TimeRemainingColumn(),
494-
TextColumn("/"),
495-
TimeElapsedColumn(),
493+
DivergenceBarColumn(
494+
table_column=Column("Progress", ratio=2),
495+
diverging_color="tab:red",
496+
diverging_finished_color="tab:purple",
497+
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
498+
finished_style=Style.parse("rgb(44,160,44)"), # tab:green
499+
),
500+
TextColumn("{task.fields[draws]:,d}", table_column=Column("Draws", ratio=1)),
501+
TextColumn(
502+
"{task.fields[divergences]:,d}", table_column=Column("Divergences", ratio=1)
503+
),
504+
TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)),
505+
TextColumn("{task.fields[tree_depth]:,d}", table_column=Column("Tree depth", ratio=1)),
506+
TextColumn(
507+
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
508+
table_column=Column("Sampling Speed", ratio=1),
509+
),
496510
console=Console(theme=progressbar_theme),
497511
disable=not progressbar,
512+
include_headers=True,
498513
)
514+
499515
self._show_progress = progressbar
500516
self._divergences = 0
517+
self._divergences_by_chain = [0] * chains
501518
self._completed_draws = 0
502-
self._total_draws = chains * (draws + tune)
503-
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
519+
self._completed_draws_by_chain = [0] * chains
520+
self._total_draws = draws + tune
521+
self._desc = "Sampling chain"
504522
self._chains = chains
505523

506524
def _make_active(self):
@@ -517,31 +535,71 @@ def __iter__(self):
517535
self._make_active()
518536

519537
with self._progress as progress:
520-
task = progress.add_task(
521-
self._desc.format(self),
522-
completed=self._completed_draws,
523-
total=self._total_draws,
524-
)
538+
tasks = [
539+
progress.add_task(
540+
self._desc.format(self),
541+
completed=self._completed_draws,
542+
total=self._total_draws,
543+
chain_idx=chain_idx,
544+
draws=0,
545+
divergences=0,
546+
step_size=0.0,
547+
tree_depth=0,
548+
sampling_speed=0,
549+
speed_unit="draws/s",
550+
)
551+
for chain_idx in range(self._chains)
552+
]
525553

526554
while self._active:
527555
draw = ProcessAdapter.recv_draw(self._active)
528556
proc, is_last, draw, tuning, stats = draw
557+
speed = 0
558+
unit = "draws/s"
559+
529560
self._completed_draws += 1
561+
self._completed_draws_by_chain[proc.chain] += 1
562+
530563
if not tuning and stats and stats[0].get("diverging"):
531564
self._divergences += 1
565+
self._divergences_by_chain[proc.chain] += 1
566+
567+
if self._show_progress:
568+
elapsed = progress._tasks[proc.chain].elapsed
569+
speed = self._completed_draws_by_chain[proc.chain] / elapsed
570+
571+
if speed > 1:
572+
unit = "draws/s"
573+
else:
574+
unit = "s/draws"
575+
speed = 1 / speed
576+
532577
progress.update(
533-
task,
534-
completed=self._completed_draws,
535-
total=self._total_draws,
536-
description=self._desc.format(self),
578+
tasks[proc.chain],
579+
completed=self._completed_draws_by_chain[proc.chain],
580+
draws=draw,
581+
divergences=self._divergences_by_chain[proc.chain],
582+
step_size=stats[0].get("step_size", 0),
583+
tree_depth=stats[0].get("tree_size", 0),
584+
sampling_speed=speed,
585+
speed_unit=unit,
537586
)
538587

539588
if is_last:
589+
self._completed_draws_by_chain[proc.chain] += 1
590+
540591
proc.join()
541592
self._active.remove(proc)
542593
self._finished.append(proc)
543594
self._make_active()
544-
progress.update(task, description=self._desc.format(self), refresh=True)
595+
progress.update(
596+
tasks[proc.chain],
597+
draws=draw + 1,
598+
divergences=self._divergences_by_chain[proc.chain],
599+
step_size=stats[0].get("step_size", 0),
600+
tree_depth=stats[0].get("tree_size", 0),
601+
refresh=True,
602+
)
545603

546604
# We could also yield proc.shared_point_view directly,
547605
# and only call proc.write_next() after the yield returns.

pymc/util.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import warnings
1818

1919
from collections import namedtuple
20-
from collections.abc import Sequence
20+
from collections.abc import Iterable, Sequence
2121
from copy import deepcopy
2222
from typing import NewType, cast
2323

@@ -30,7 +30,10 @@
3030
from pytensor import Variable
3131
from pytensor.compile import SharedVariable
3232
from pytensor.graph.utils import ValidatingScratchpad
33-
from rich.progress import Progress
33+
from rich.box import SIMPLE_HEAD
34+
from rich.progress import BarColumn, Progress, Task
35+
from rich.style import Style
36+
from rich.table import Column, Table
3437
from rich.theme import Theme
3538

3639
from pymc.exceptions import BlockModelAccessError
@@ -556,8 +559,10 @@ class CustomProgress(Progress):
556559
it's `True`.
557560
"""
558561

559-
def __init__(self, *args, **kwargs):
560-
self.is_enabled = kwargs.get("disable", None) is not True
562+
def __init__(self, *args, disable=False, include_headers=False, **kwargs):
563+
self.is_enabled = not disable
564+
self.include_headers = include_headers
565+
561566
if self.is_enabled:
562567
super().__init__(*args, **kwargs)
563568

@@ -607,6 +612,88 @@ def update(
607612
)
608613
return None
609614

615+
def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
616+
"""Get a table to render the Progress display.
617+
618+
Unlike the parent method, this one returns a full table (not a grid), allowing for column headings.
619+
620+
Parameters
621+
----------
622+
tasks: Iterable[Task]
623+
An iterable of Task instances, one per row of the table.
624+
625+
Returns
626+
-------
627+
table: Table
628+
A table instance.
629+
"""
630+
631+
def call_column(column, task):
632+
if hasattr(column, "callbacks"):
633+
column.callbacks(task)
634+
635+
return column(task)
636+
637+
table_columns = (
638+
(
639+
Column(no_wrap=True)
640+
if isinstance(_column, str)
641+
else _column.get_table_column().copy()
642+
)
643+
for _column in self.columns
644+
)
645+
if self.include_headers:
646+
table = Table(
647+
*table_columns,
648+
padding=(0, 1),
649+
expand=self.expand,
650+
show_header=True,
651+
show_edge=True,
652+
box=SIMPLE_HEAD,
653+
)
654+
else:
655+
table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand)
656+
657+
for task in tasks:
658+
if task.visible:
659+
table.add_row(
660+
*(
661+
(
662+
column.format(task=task)
663+
if isinstance(column, str)
664+
else call_column(column, task)
665+
)
666+
for column in self.columns
667+
)
668+
)
669+
670+
return table
671+
672+
673+
class DivergenceBarColumn(BarColumn):
674+
def __init__(self, *args, diverging_color="red", diverging_finished_color="purple", **kwargs):
675+
from matplotlib.colors import to_rgb
676+
677+
self.diverging_color = diverging_color
678+
self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)]
679+
680+
self.diverging_finished_color = diverging_finished_color
681+
self.diverging_finished_rgb = [int(x * 255) for x in to_rgb(self.diverging_finished_color)]
682+
683+
super().__init__(*args, **kwargs)
684+
685+
self.non_diverging_style = self.complete_style
686+
self.non_diverging_finished_style = self.finished_style
687+
688+
def callbacks(self, task: "Task"):
689+
divergences = task.fields.get("divergences", 0)
690+
if divergences > 0:
691+
self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
692+
self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_finished_rgb))
693+
else:
694+
self.complete_style = self.non_diverging_style
695+
self.finished_style = self.non_diverging_finished_style
696+
610697

611698
RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"])
612699

0 commit comments

Comments
 (0)