Skip to content

Commit 4e535d4

Browse files
Update docstrings
1 parent e024991 commit 4e535d4

File tree

2 files changed

+83
-8
lines changed

2 files changed

+83
-8
lines changed

pymc/sampling/mcmc.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from pymc.step_methods.hmc import quadpotential
6767
from pymc.util import (
6868
ProgressManager,
69+
ProgressType,
6970
RandomSeed,
7071
RandomState,
7172
_get_seeds_per_chain,
@@ -486,7 +487,7 @@ def sample(
486487
chains: int | None = None,
487488
cores: int | None = None,
488489
random_seed: RandomState = None,
489-
progressbar: bool = True,
490+
progressbar: bool | ProgressType = True,
490491
progressbar_theme: Theme | None = default_progress_theme,
491492
step=None,
492493
var_names: Sequence[str] | None = None,
@@ -537,11 +538,19 @@ def sample(
537538
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
538539
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
539540
easy spawning of new independent random streams that are needed by the step methods.
540-
progressbar : bool, optional default=True
541-
Whether or not to display a progress bar in the command line. The bar shows the percentage
542-
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
543-
time until completion ("expected time of arrival"; ETA).
544-
Only applicable to the pymc nuts sampler.
541+
progressbar: bool or ProgressType, optional
542+
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
543+
for either:
544+
- "combined": A single progress bar that displays the progress of all chains combined.
545+
- "chain": A separate progress bar for each chain.
546+
547+
You can also combine the above options with:
548+
- "simple": A simple progress bar that displays only timing information alongside the progress bar.
549+
- "full": A progress bar that displays all available statistics.
550+
551+
These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple".
552+
553+
If True, the default is "chain+full".
545554
step : function or iterable of functions
546555
A step function or collection of functions. If there are variables without step methods,
547556
step methods for those variables will be assigned automatically. By default the NUTS step

pymc/util.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from collections import namedtuple
2020
from collections.abc import Iterable, Sequence
2121
from copy import deepcopy
22-
from typing import NewType, cast
22+
from typing import TYPE_CHECKING, Literal, NewType, cast
2323

2424
import arviz
2525
import cloudpickle
@@ -46,6 +46,23 @@
4646

4747
from pymc.exceptions import BlockModelAccessError
4848

49+
if TYPE_CHECKING:
50+
from pymc import BlockedStep
51+
52+
53+
ProgressType = Literal[
54+
"chain",
55+
"combined",
56+
"simple",
57+
"full",
58+
"combined+full",
59+
"full+combined",
60+
"combined+simple",
61+
"simple+combined",
62+
"chain+full",
63+
"full+chain",
64+
]
65+
4966

5067
def __getattr__(name):
5168
if name == "dataset_to_point_list":
@@ -639,6 +656,7 @@ def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
639656
"""
640657

641658
def call_column(column, task):
659+
# Subclass rich.BarColumn and add a callback method to dynamically update the display
642660
if hasattr(column, "callbacks"):
643661
column.callbacks(task)
644662

@@ -681,6 +699,8 @@ def call_column(column, task):
681699

682700

683701
class DivergenceBarColumn(BarColumn):
702+
"""Rich colorbar that changes color when a chain has detected a divergence."""
703+
684704
def __init__(self, *args, diverging_color="red", **kwargs):
685705
from matplotlib.colors import to_rgb
686706

@@ -703,7 +723,53 @@ def callbacks(self, task: "Task"):
703723

704724

705725
class ProgressManager:
706-
def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme):
726+
"""Manage progress bars displayed during sampling."""
727+
728+
def __init__(
729+
self,
730+
step_method: BlockedStep,
731+
chains: int,
732+
draws: int,
733+
tune: int,
734+
progressbar: bool | ProgressType = True,
735+
progressbar_theme: Theme = default_progress_theme,
736+
):
737+
"""
738+
Manage progress bars displayed during sampling.
739+
740+
When sampling, Step classes are responsible for computing and exposing statistics that can be reported on
741+
progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config`
742+
and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which
743+
columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics
744+
that will be displayed on the progress bar.
745+
746+
Parameters
747+
----------
748+
step_method: BlockedStep
749+
The step method being used to sample
750+
chains: int
751+
Number of chains being sampled
752+
draws: int
753+
Number of draws per chain
754+
tune: int
755+
Number of tuning steps per chain
756+
progressbar: bool or ProgressType, optional
757+
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
758+
for either:
759+
- "combined": A single progress bar that displays the progress of all chains combined.
760+
- "chain": A separate progress bar for each chain.
761+
762+
You can also combine the above options with:
763+
- "simple": A simple progress bar that displays only timing information alongside the progress bar.
764+
- "full": A progress bar that displays all available statistics.
765+
766+
These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple".
767+
768+
If True, the default is "chain+full".
769+
770+
progressbar_theme: Theme, optional
771+
The theme to use for the progress bar. Defaults to the default theme.
772+
"""
707773
self.combined_progress = False
708774
self.full_stats = True
709775
show_progress = True

0 commit comments

Comments
 (0)