@@ -302,7 +302,7 @@ def _sample_external_nuts(
302302 initvals : StartDict | Sequence [StartDict | None ] | None ,
303303 model : Model ,
304304 var_names : Sequence [str ] | None ,
305- progressbar : bool ,
305+ progressbar : bool | ProgressType ,
306306 idata_kwargs : dict | None ,
307307 compute_convergence_checks : bool ,
308308 nuts_sampler_kwargs : dict | None ,
@@ -401,7 +401,7 @@ def _sample_external_nuts(
401401 initvals = initvals ,
402402 model = model ,
403403 var_names = var_names ,
404- progressbar = progressbar ,
404+ progressbar = True if progressbar else False ,
405405 nuts_sampler = sampler ,
406406 idata_kwargs = idata_kwargs ,
407407 compute_convergence_checks = compute_convergence_checks ,
@@ -488,7 +488,7 @@ def sample(
488488 cores : int | None = None ,
489489 random_seed : RandomState = None ,
490490 progressbar : bool | ProgressType = True ,
491- progressbar_theme : Theme | None = default_progress_theme ,
491+ progressbar_theme : Theme | None = None ,
492492 step = None ,
493493 var_names : Sequence [str ] | None = None ,
494494 nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -831,7 +831,9 @@ def joined_blas_limiter():
831831 n_init = n_init ,
832832 model = model ,
833833 random_seed = random_seed_list ,
834- progressbar = progressbar ,
834+ progressbar = True
835+ if progressbar
836+ else False , # ADVI doesn't use the ProgressManager; pass a bool only
835837 jitter_max_retries = jitter_max_retries ,
836838 tune = tune ,
837839 initvals = initvals ,
0 commit comments