Skip to content

Commit 161d10c

Browse files
Incorporate feedback
1 parent 79d1248 commit 161d10c

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

pymc/sampling/mcmc.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@
6565
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
6666
from pymc.step_methods.hmc import quadpotential
6767
from pymc.util import (
68-
ProgressManager,
69-
ProgressType,
68+
ProgressBarManager,
69+
ProgressBarType,
7070
RandomSeed,
7171
RandomState,
7272
_get_seeds_per_chain,
@@ -423,7 +423,7 @@ def sample(
423423
chains: int | None = None,
424424
cores: int | None = None,
425425
random_seed: RandomState = None,
426-
progressbar: bool | ProgressType = True,
426+
progressbar: bool | ProgressBarType = True,
427427
progressbar_theme: Theme | None = default_progress_theme,
428428
step=None,
429429
var_names: Sequence[str] | None = None,
@@ -455,7 +455,7 @@ def sample(
455455
chains: int | None = None,
456456
cores: int | None = None,
457457
random_seed: RandomState = None,
458-
progressbar: bool | ProgressType = True,
458+
progressbar: bool | ProgressBarType = True,
459459
progressbar_theme: Theme | None = default_progress_theme,
460460
step=None,
461461
var_names: Sequence[str] | None = None,
@@ -487,7 +487,7 @@ def sample(
487487
chains: int | None = None,
488488
cores: int | None = None,
489489
random_seed: RandomState = None,
490-
progressbar: bool | ProgressType = True,
490+
progressbar: bool | ProgressBarType = True,
491491
progressbar_theme: Theme | None = None,
492492
step=None,
493493
var_names: Sequence[str] | None = None,
@@ -717,7 +717,7 @@ def sample(
717717

718718
# progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and
719719
# ADVI initialization expect just a bool.
720-
progress_bool = True if progressbar else False
720+
progress_bool = bool(progressbar)
721721

722722
model = modelcontext(model)
723723
if not model.free_RVs:
@@ -1148,7 +1148,7 @@ def _sample_many(
11481148
Step function
11491149
"""
11501150
initial_step_state = step.sampling_state
1151-
progress_manager = ProgressManager(
1151+
progress_manager = ProgressBarManager(
11521152
step_method=step,
11531153
chains=chains,
11541154
draws=draws - kwargs.get("tune", 0),
@@ -1185,7 +1185,7 @@ def _sample(
11851185
tune: int,
11861186
model: Model | None = None,
11871187
callback=None,
1188-
progress_manager: ProgressManager,
1188+
progress_manager: ProgressBarManager,
11891189
**kwargs,
11901190
) -> None:
11911191
"""Sample one chain (singleprocess).
@@ -1210,7 +1210,7 @@ def _sample(
12101210
Number of iterations to tune.
12111211
model : Model, optional
12121212
PyMC model. If None, the model is taken from the current context.
1213-
progress_manager: ProgressManager
1213+
progress_manager: ProgressBarManager
12141214
Helper class used to handle progress bar styling and updates
12151215
"""
12161216
sampling_gen = _iter_sample(

pymc/sampling/parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from pymc.blocking import DictToArrayBijection
3535
from pymc.exceptions import SamplingError
3636
from pymc.util import (
37-
ProgressManager,
37+
ProgressBarManager,
3838
RandomGeneratorState,
3939
default_progress_theme,
4040
get_state_from_generator,
@@ -483,7 +483,7 @@ def __init__(
483483
self._max_active = cores
484484

485485
self._in_context = False
486-
self._progress = ProgressManager(
486+
self._progress = ProgressBarManager(
487487
step_method=step_method,
488488
chains=chains,
489489
draws=draws,

pymc/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from pymc.step_methods.compound import BlockedStep, CompoundStep
5151

5252

53-
ProgressType = Literal[
53+
ProgressBarType = Literal[
5454
"combined",
5555
"split",
5656
"combined+stats",
@@ -718,7 +718,7 @@ def callbacks(self, task: "Task"):
718718
self.finished_style = self.non_diverging_finished_style
719719

720720

721-
class ProgressManager:
721+
class ProgressBarManager:
722722
"""Manage progress bars displayed during sampling."""
723723

724724
def __init__(
@@ -727,7 +727,7 @@ def __init__(
727727
chains: int,
728728
draws: int,
729729
tune: int,
730-
progressbar: bool | ProgressType = True,
730+
progressbar: bool | ProgressBarType = True,
731731
progressbar_theme: Theme | None = None,
732732
):
733733
"""

0 commit comments

Comments
 (0)