Skip to content

Commit 741cf36

Browse files
Explicit case handling for progressbar argument
1 parent 345faff commit 741cf36

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

pymc/util.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -704,28 +704,38 @@ def callbacks(self, task: "Task"):
704704

705705
class ProgressManager:
706706
def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme):
707-
mode = "chain"
708-
stats = "full"
709-
710-
if isinstance(progressbar, bool):
711-
show_progress = progressbar
712-
else:
713-
show_progress = True
714-
715-
if "+" in progressbar:
716-
mode, stats = progressbar.split("+")
717-
else:
718-
mode = progressbar
719-
stats = "full"
720-
721-
if mode not in ["chain", "combined"]:
722-
raise ValueError('Invalid mode. Valid values are "chain" and "combined"')
723-
if stats not in ["full", "simple"]:
724-
raise ValueError('Invalid stats. Valid values are "full" and "simple"')
707+
self.combined_progress = False
708+
self.full_stats = True
709+
show_progress = True
710+
711+
match progressbar:
712+
case True:
713+
show_progress = True
714+
case False:
715+
show_progress = False
716+
case "combined":
717+
self.combined_progress = True
718+
case "chain":
719+
self.combined_progress = False
720+
case "combined+full":
721+
self.combined_progress = True
722+
self.full_stats = True
723+
case "combined+simple":
724+
self.combined_progress = True
725+
self.full_stats = False
726+
case "chain+full":
727+
self.combined_progress = False
728+
self.full_stats = True
729+
case "chain+simple":
730+
self.combined_progress = False
731+
self.full_stats = False
732+
case _:
733+
raise ValueError(
734+
"Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), "
735+
"or one of 'combined', 'chain', 'combined+full', 'combined+simple', 'chain+full', 'chain+simple'."
736+
)
725737

726738
progress_columns, progress_stats = step_method._progressbar_config(chains)
727-
self.combined_progress = mode == "combined"
728-
self.full_stats = stats == "full"
729739

730740
self._progress = self.create_progress_bar(
731741
progress_columns,

0 commit comments

Comments
 (0)