@@ -302,7 +302,7 @@ def _sample_external_nuts(
302302 initvals : StartDict | Sequence [StartDict | None ] | None ,
303303 model : Model ,
304304 var_names : Sequence [str ] | None ,
305- progressbar : bool | ProgressType ,
305+ progressbar : bool ,
306306 idata_kwargs : dict | None ,
307307 compute_convergence_checks : bool ,
308308 nuts_sampler_kwargs : dict | None ,
@@ -401,7 +401,7 @@ def _sample_external_nuts(
401401 initvals = initvals ,
402402 model = model ,
403403 var_names = var_names ,
404- progressbar = True if progressbar else False ,
404+ progressbar = progressbar ,
405405 nuts_sampler = sampler ,
406406 idata_kwargs = idata_kwargs ,
407407 compute_convergence_checks = compute_convergence_checks ,
@@ -423,7 +423,7 @@ def sample(
423423 chains : int | None = None ,
424424 cores : int | None = None ,
425425 random_seed : RandomState = None ,
426- progressbar : bool = True ,
426+ progressbar : bool | ProgressType = True ,
427427 progressbar_theme : Theme | None = default_progress_theme ,
428428 step = None ,
429429 var_names : Sequence [str ] | None = None ,
@@ -455,7 +455,7 @@ def sample(
455455 chains : int | None = None ,
456456 cores : int | None = None ,
457457 random_seed : RandomState = None ,
458- progressbar : bool = True ,
458+ progressbar : bool | ProgressType = True ,
459459 progressbar_theme : Theme | None = default_progress_theme ,
460460 step = None ,
461461 var_names : Sequence [str ] | None = None ,
@@ -540,17 +540,16 @@ def sample(
540540 easy spawning of new independent random streams that are needed by the step methods.
541541 progressbar: bool or ProgressType, optional
542542 How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
543- for either:
544- - "combined": A single progress bar that displays the progress of all chains combined.
545- - "chain": A separate progress bar for each chain.
546-
547- You can also combine the above options with:
548- - "simple": A simple progress bar that displays only timing information alongside the progress bar.
549- - "full": A progress bar that displays all available statistics.
550-
551- These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple".
552-
553- If True, the default is "chain+full".
543+ for one of the following:
544+ - "combined": A single progress bar that displays the total progress across all chains. Only timing
545+ information is shown.
546+ - "split": A separate progress bar for each chain. Only timing information is shown.
547+ - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
548+ chains. Aggregate sample statistics are also displayed.
549+ - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
550+ are also displayed.
551+
552+ If True, the default is "split+stats" is used.
554553 step : function or iterable of functions
555554 A step function or collection of functions. If there are variables without step methods,
556555 step methods for those variables will be assigned automatically. By default the NUTS step
@@ -716,6 +715,10 @@ def sample(
716715 if isinstance (trace , list ):
717716 raise ValueError ("Please use `var_names` keyword argument for partial traces." )
718717
718+ # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and
719+ # ADVI initialization expect just a bool.
720+ progress_bool = True if progressbar else False
721+
719722 model = modelcontext (model )
720723 if not model .free_RVs :
721724 raise SamplingError (
@@ -812,7 +815,7 @@ def joined_blas_limiter():
812815 initvals = initvals ,
813816 model = model ,
814817 var_names = var_names ,
815- progressbar = progressbar ,
818+ progressbar = progress_bool ,
816819 idata_kwargs = idata_kwargs ,
817820 compute_convergence_checks = compute_convergence_checks ,
818821 nuts_sampler_kwargs = nuts_sampler_kwargs ,
@@ -831,9 +834,7 @@ def joined_blas_limiter():
831834 n_init = n_init ,
832835 model = model ,
833836 random_seed = random_seed_list ,
834- progressbar = True
835- if progressbar
836- else False , # ADVI doesn't use the ProgressManager; pass a bool only
837+ progressbar = progress_bool ,
837838 jitter_max_retries = jitter_max_retries ,
838839 tune = tune ,
839840 initvals = initvals ,
0 commit comments