Skip to content

Commit d61ddf6

Browse files
Create ProgressManager class to handle progress bars
1 parent a96d7bb commit d61ddf6

File tree

2 files changed

+167
-79
lines changed

2 files changed

+167
-79
lines changed

pymc/sampling/parallel.py

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@
3434
from pymc.blocking import DictToArrayBijection
3535
from pymc.exceptions import SamplingError
3636
from pymc.util import (
37+
ProgressManager,
3738
RandomGeneratorState,
38-
compute_draw_speed,
39-
create_progress_bar,
4039
default_progress_theme,
4140
get_state_from_generator,
4241
random_generator_from_state,
@@ -484,26 +483,15 @@ def __init__(
484483
self._max_active = cores
485484

486485
self._in_context = False
487-
488-
progress_columns, progress_stats = step_method._progressbar_config(chains)
489-
490-
self._progress = create_progress_bar(
491-
progress_columns,
492-
progress_stats,
486+
self._progress = ProgressManager(
487+
step_method=step_method,
488+
chains=chains,
489+
draws=draws,
490+
tune=tune,
493491
progressbar=progressbar,
494492
progressbar_theme=progressbar_theme,
495493
)
496494

497-
self.progress_stats = progress_stats
498-
self.update_stats = step_method._make_update_stats_function()
499-
500-
self._show_progress = progressbar
501-
self._divergences = 0
502-
self._completed_draws = 0
503-
self._total_draws = draws + tune
504-
self._desc = "Sampling chain"
505-
self._chains = chains
506-
507495
def _make_active(self):
508496
while self._inactive and len(self._active) < self._max_active:
509497
proc = self._inactive.pop(0)
@@ -517,54 +505,20 @@ def __iter__(self):
517505
raise ValueError("Use ParallelSampler as context manager.")
518506
self._make_active()
519507

520-
with self._progress as progress:
521-
tasks = [
522-
progress.add_task(
523-
self._desc.format(self),
524-
completed=0,
525-
draws=0,
526-
total=self._total_draws - 1,
527-
chain_idx=chain_idx,
528-
sampling_speed=0,
529-
speed_unit="draws/s",
530-
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
531-
)
532-
for chain_idx in range(self._chains)
533-
]
534-
508+
with self._progress:
535509
while self._active:
536510
draw = ProcessAdapter.recv_draw(self._active)
537511
proc, is_last, draw, tuning, stats = draw
538512

539-
self._completed_draws += 1
540-
541-
speed, unit = compute_draw_speed(progress._tasks[proc.chain].elapsed, draw)
542-
543-
if not tuning and stats and stats[0].get("diverging"):
544-
self._divergences += 1
545-
546-
self.progress_stats = self.update_stats(self.progress_stats, stats, proc.chain)
547-
548-
progress.update(
549-
tasks[proc.chain],
550-
completed=draw,
551-
draws=draw,
552-
sampling_speed=speed,
553-
speed_unit=unit,
554-
**{stat: value[proc.chain] for stat, value in self.progress_stats.items()},
513+
self._progress.update(
514+
chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
555515
)
556516

557517
if is_last:
558518
proc.join()
559519
self._active.remove(proc)
560520
self._finished.append(proc)
561521
self._make_active()
562-
progress.update(
563-
tasks[proc.chain],
564-
draws=draw + 1,
565-
**{stat: value[proc.chain] for stat, value in self.progress_stats.items()},
566-
refresh=True,
567-
)
568522

569523
# We could also yield proc.shared_point_view directly,
570524
# and only call proc.write_next() after the yield returns.

pymc/util.py

Lines changed: 158 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -702,30 +702,164 @@ def callbacks(self, task: "Task"):
702702
self.finished_style = self.non_diverging_finished_style
703703

704704

705-
def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_theme):
706-
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
707-
columns += step_columns
708-
columns += [
709-
TextColumn(
710-
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
711-
table_column=Column("Sampling Speed", ratio=1),
712-
),
713-
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)),
714-
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)),
715-
]
716-
717-
return CustomProgress(
718-
DivergenceBarColumn(
719-
table_column=Column("Progress", ratio=2),
720-
diverging_color="tab:red",
721-
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
722-
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
723-
),
724-
*columns,
725-
console=Console(theme=progressbar_theme),
726-
disable=not progressbar,
727-
include_headers=True,
728-
)
705+
class ProgressManager:
706+
def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme):
707+
mode = "chain"
708+
stats = "full"
709+
710+
if isinstance(progressbar, bool):
711+
show_progress = progressbar
712+
else:
713+
show_progress = True
714+
715+
if "+" in progressbar:
716+
mode, stats = progressbar.split("+")
717+
else:
718+
mode = progressbar
719+
stats = "full"
720+
721+
if mode not in ["chain", "combined"]:
722+
raise ValueError('Invalid mode. Valid values are "chain" and "combined"')
723+
if stats not in ["full", "simple"]:
724+
raise ValueError('Invalid stats. Valid values are "full" and "simple"')
725+
726+
progress_columns, progress_stats = step_method._progressbar_config(chains)
727+
self.combined_progress = mode == "combined"
728+
self.full_stats = stats == "full"
729+
730+
self._progress = self.create_progress_bar(
731+
progress_columns,
732+
progressbar=progressbar,
733+
progressbar_theme=progressbar_theme,
734+
)
735+
736+
self.progress_stats = progress_stats
737+
self.update_stats = step_method._make_update_stats_function()
738+
739+
self._show_progress = show_progress
740+
self.divergences = 0
741+
self.completed_draws = 0
742+
self.total_draws = draws + tune
743+
self.desc = "Sampling chain"
744+
self.chains = chains
745+
746+
self._tasks: list[Task] | None = None
747+
748+
def __enter__(self):
749+
self._initialize_tasks()
750+
751+
return self._progress.__enter__()
752+
753+
def __exit__(self, exc_type, exc_val, exc_tb):
754+
return self._progress.__exit__(exc_type, exc_val, exc_tb)
755+
756+
def _initialize_tasks(self):
757+
if self.combined_progress:
758+
self.tasks = [
759+
self._progress.add_task(
760+
self.desc.format(self),
761+
completed=0,
762+
draws=0,
763+
total=self.total_draws * self.chains - 1,
764+
chain_idx=0,
765+
sampling_speed=0,
766+
speed_unit="draws/s",
767+
**{stat: value[0] for stat, value in self.progress_stats.items()},
768+
)
769+
]
770+
771+
else:
772+
self.tasks = [
773+
self._progress.add_task(
774+
self.desc.format(self),
775+
completed=0,
776+
draws=0,
777+
total=self.total_draws - 1,
778+
chain_idx=chain_idx,
779+
sampling_speed=0,
780+
speed_unit="draws/s",
781+
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
782+
)
783+
for chain_idx in range(self.chains)
784+
]
785+
786+
def compute_draw_speed(self, chain_idx, draws):
787+
elapsed = self._progress.tasks[chain_idx].elapsed
788+
speed = draws / max(elapsed, 1e-6)
789+
790+
if speed > 1 or speed == 0:
791+
unit = "draws/s"
792+
else:
793+
unit = "s/draws"
794+
speed = 1 / speed
795+
796+
return speed, unit
797+
798+
def update(self, chain_idx, is_last, draw, tuning, stats):
799+
if not self._show_progress:
800+
return
801+
802+
self.completed_draws += 1
803+
if self.combined_progress:
804+
draw = self.completed_draws
805+
chain_idx = 0
806+
807+
speed, unit = self.compute_draw_speed(chain_idx, draw)
808+
809+
if not tuning and stats and stats[0].get("diverging"):
810+
self.divergences += 1
811+
812+
self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
813+
more_updates = (
814+
{stat: value[chain_idx] for stat, value in self.progress_stats.items()}
815+
if self.full_stats
816+
else {}
817+
)
818+
819+
self._progress.update(
820+
self.tasks[chain_idx],
821+
completed=draw,
822+
draws=draw,
823+
sampling_speed=speed,
824+
speed_unit=unit,
825+
**more_updates,
826+
)
827+
828+
if is_last:
829+
self._progress.update(
830+
self.tasks[chain_idx],
831+
draws=draw + 1 if not self.combined_progress else draw - 1,
832+
**more_updates,
833+
refresh=True,
834+
)
835+
836+
def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
837+
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
838+
839+
if self.full_stats:
840+
columns += step_columns
841+
842+
columns += [
843+
TextColumn(
844+
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
845+
table_column=Column("Sampling Speed", ratio=1),
846+
),
847+
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)),
848+
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)),
849+
]
850+
851+
return CustomProgress(
852+
DivergenceBarColumn(
853+
table_column=Column("Progress", ratio=2),
854+
diverging_color="tab:red",
855+
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
856+
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
857+
),
858+
*columns,
859+
console=Console(theme=progressbar_theme),
860+
disable=not progressbar,
861+
include_headers=True,
862+
)
729863

730864

731865
def compute_draw_speed(elapsed, draws):

0 commit comments

Comments
 (0)