Skip to content

Commit 79d1248

Browse files
Simplify progressbar choices, update docstring
1 parent 9de9930 commit 79d1248

File tree

2 files changed

+39
-51
lines changed

2 files changed

+39
-51
lines changed

pymc/sampling/mcmc.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _sample_external_nuts(
302302
initvals: StartDict | Sequence[StartDict | None] | None,
303303
model: Model,
304304
var_names: Sequence[str] | None,
305-
progressbar: bool | ProgressType,
305+
progressbar: bool,
306306
idata_kwargs: dict | None,
307307
compute_convergence_checks: bool,
308308
nuts_sampler_kwargs: dict | None,
@@ -401,7 +401,7 @@ def _sample_external_nuts(
401401
initvals=initvals,
402402
model=model,
403403
var_names=var_names,
404-
progressbar=True if progressbar else False,
404+
progressbar=progressbar,
405405
nuts_sampler=sampler,
406406
idata_kwargs=idata_kwargs,
407407
compute_convergence_checks=compute_convergence_checks,
@@ -423,7 +423,7 @@ def sample(
423423
chains: int | None = None,
424424
cores: int | None = None,
425425
random_seed: RandomState = None,
426-
progressbar: bool = True,
426+
progressbar: bool | ProgressType = True,
427427
progressbar_theme: Theme | None = default_progress_theme,
428428
step=None,
429429
var_names: Sequence[str] | None = None,
@@ -455,7 +455,7 @@ def sample(
455455
chains: int | None = None,
456456
cores: int | None = None,
457457
random_seed: RandomState = None,
458-
progressbar: bool = True,
458+
progressbar: bool | ProgressType = True,
459459
progressbar_theme: Theme | None = default_progress_theme,
460460
step=None,
461461
var_names: Sequence[str] | None = None,
@@ -540,17 +540,16 @@ def sample(
540540
easy spawning of new independent random streams that are needed by the step methods.
541541
progressbar: bool or ProgressType, optional
542542
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
543-
for either:
544-
- "combined": A single progress bar that displays the progress of all chains combined.
545-
- "chain": A separate progress bar for each chain.
546-
547-
You can also combine the above options with:
548-
- "simple": A simple progress bar that displays only timing information alongside the progress bar.
549-
- "full": A progress bar that displays all available statistics.
550-
551-
These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple".
552-
553-
If True, the default is "chain+full".
543+
for one of the following:
544+
- "combined": A single progress bar that displays the total progress across all chains. Only timing
545+
information is shown.
546+
- "split": A separate progress bar for each chain. Only timing information is shown.
547+
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
548+
chains. Aggregate sample statistics are also displayed.
549+
- "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
550+
are also displayed.
551+
552+
If True, the default is "split+stats" is used.
554553
step : function or iterable of functions
555554
A step function or collection of functions. If there are variables without step methods,
556555
step methods for those variables will be assigned automatically. By default the NUTS step
@@ -716,6 +715,10 @@ def sample(
716715
if isinstance(trace, list):
717716
raise ValueError("Please use `var_names` keyword argument for partial traces.")
718717

718+
# progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and
719+
# ADVI initialization expect just a bool.
720+
progress_bool = True if progressbar else False
721+
719722
model = modelcontext(model)
720723
if not model.free_RVs:
721724
raise SamplingError(
@@ -812,7 +815,7 @@ def joined_blas_limiter():
812815
initvals=initvals,
813816
model=model,
814817
var_names=var_names,
815-
progressbar=progressbar,
818+
progressbar=progress_bool,
816819
idata_kwargs=idata_kwargs,
817820
compute_convergence_checks=compute_convergence_checks,
818821
nuts_sampler_kwargs=nuts_sampler_kwargs,
@@ -831,9 +834,7 @@ def joined_blas_limiter():
831834
n_init=n_init,
832835
model=model,
833836
random_seed=random_seed_list,
834-
progressbar=True
835-
if progressbar
836-
else False, # ADVI doesn't use the ProgressManager; pass a bool only
837+
progressbar=progress_bool,
837838
jitter_max_retries=jitter_max_retries,
838839
tune=tune,
839840
initvals=initvals,

pymc/util.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,12 @@
5151

5252

5353
ProgressType = Literal[
54-
"chain",
5554
"combined",
56-
"simple",
57-
"full",
58-
"combined+full",
59-
"full+combined",
60-
"combined+simple",
61-
"simple+combined",
62-
"chain+full",
63-
"full+chain",
55+
"split",
56+
"combined+stats",
57+
"stats+combined",
58+
"split+stats",
59+
"stats+split",
6460
]
6561

6662

@@ -755,17 +751,16 @@ def __init__(
755751
Number of tuning steps per chain
756752
progressbar: bool or ProgressType, optional
757753
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
758-
for either:
759-
- "combined": A single progress bar that displays the progress of all chains combined.
760-
- "chain": A separate progress bar for each chain.
754+
for one of the following:
755+
- "combined": A single progress bar that displays the total progress across all chains. Only timing
756+
information is shown.
757+
- "split": A separate progress bar for each chain. Only timing information is shown.
758+
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
759+
chains. Aggregate sample statistics are also displayed.
760+
- "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
761+
are also displayed.
761762
762-
You can also combine the above options with:
763-
- "simple": A simple progress bar that displays only timing information alongside the progress bar.
764-
- "full": A progress bar that displays all available statistics.
765-
766-
These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple".
767-
768-
If True, the default is "chain+full".
763+
If True, the default is "split+stats" is used.
769764
770765
progressbar_theme: Theme, optional
771766
The theme to use for the progress bar. Defaults to the default theme.
@@ -784,28 +779,20 @@ def __init__(
784779
show_progress = False
785780
case "combined":
786781
self.combined_progress = True
787-
case "chain":
782+
self.full_stats = False
783+
case "split":
788784
self.combined_progress = False
789-
case "simple":
790785
self.full_stats = False
791-
case "full":
792-
self.full_stats = True
793-
case "combined+full" | "full+combined":
794-
self.combined_progress = True
786+
case "combined+stats" | "stats+combined":
795787
self.full_stats = True
796-
case "combined+simple" | "simple+combined":
797788
self.combined_progress = True
798-
self.full_stats = False
799-
case "chain+full" | "full+chain":
800-
self.combined_progress = False
789+
case "split+stats" | "stats+split":
801790
self.full_stats = True
802-
case "chain+simple" | "simple+chain":
803791
self.combined_progress = False
804-
self.full_stats = False
805792
case _:
806793
raise ValueError(
807794
"Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), "
808-
"one of 'combined', 'chain', 'simple', 'full', or a '+' delimited pair of two of these values."
795+
"one of 'combined', 'split', 'split+stats', or 'combined+stats."
809796
)
810797

811798
progress_columns, progress_stats = step_method._progressbar_config(chains)

0 commit comments

Comments
 (0)