Skip to content

Commit b9b0583

Browse files
mypy + cleanup
1 parent 4e535d4 commit b9b0583

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

pymc/sampling/mcmc.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _sample_external_nuts(
302302
initvals: StartDict | Sequence[StartDict | None] | None,
303303
model: Model,
304304
var_names: Sequence[str] | None,
305-
progressbar: bool,
305+
progressbar: bool | ProgressType,
306306
idata_kwargs: dict | None,
307307
compute_convergence_checks: bool,
308308
nuts_sampler_kwargs: dict | None,
@@ -401,7 +401,7 @@ def _sample_external_nuts(
401401
initvals=initvals,
402402
model=model,
403403
var_names=var_names,
404-
progressbar=progressbar,
404+
progressbar=True if progressbar else False,
405405
nuts_sampler=sampler,
406406
idata_kwargs=idata_kwargs,
407407
compute_convergence_checks=compute_convergence_checks,
@@ -488,7 +488,7 @@ def sample(
488488
cores: int | None = None,
489489
random_seed: RandomState = None,
490490
progressbar: bool | ProgressType = True,
491-
progressbar_theme: Theme | None = default_progress_theme,
491+
progressbar_theme: Theme | None = None,
492492
step=None,
493493
var_names: Sequence[str] | None = None,
494494
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
@@ -831,7 +831,9 @@ def joined_blas_limiter():
831831
n_init=n_init,
832832
model=model,
833833
random_seed=random_seed_list,
834-
progressbar=progressbar,
834+
progressbar=True
835+
if progressbar
836+
else False, # ADVI doesn't use the ProgressManager; pass a bool only
835837
jitter_max_retries=jitter_max_retries,
836838
tune=tune,
837839
initvals=initvals,

pymc/util.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from pymc.exceptions import BlockModelAccessError
4848

4949
if TYPE_CHECKING:
50-
from pymc import BlockedStep
50+
from pymc.step_methods.compound import BlockedStep, CompoundStep
5151

5252

5353
ProgressType = Literal[
@@ -727,12 +727,12 @@ class ProgressManager:
727727

728728
def __init__(
729729
self,
730-
step_method: BlockedStep,
730+
step_method: "BlockedStep" | "CompoundStep",
731731
chains: int,
732732
draws: int,
733733
tune: int,
734734
progressbar: bool | ProgressType = True,
735-
progressbar_theme: Theme = default_progress_theme,
735+
progressbar_theme: Theme | None = None,
736736
):
737737
"""
738738
Manage progress bars displayed during sampling.
@@ -770,6 +770,9 @@ def __init__(
770770
progressbar_theme: Theme, optional
771771
The theme to use for the progress bar. Defaults to the default theme.
772772
"""
773+
if progressbar_theme is None:
774+
progressbar_theme = default_progress_theme
775+
773776
self.combined_progress = False
774777
self.full_stats = True
775778
show_progress = True

0 commit comments

Comments
 (0)