Skip to content

Commit 3180533

Browse files
Remove reference to divergence in progress bar
1 parent 4b4c28d commit 3180533

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

pymc/progress_bar.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ def call_column(column, task):
153153
return table
154154

155155

156-
class DivergenceBarColumn(BarColumn):
157-
"""Rich colorbar that changes color when a chain has detected a divergence."""
156+
class RecolorOnFailureBarColumn(BarColumn):
157+
"""Rich colorbar that changes color when a chain has detected a failure."""
158158

159-
def __init__(self, *args, diverging_color="red", **kwargs):
159+
def __init__(self, *args, failing_color="red", **kwargs):
160160
from matplotlib.colors import to_rgb
161161

162-
self.failing_color = diverging_color
162+
self.failing_color = failing_color
163163
self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)]
164164

165165
super().__init__(*args, **kwargs)
@@ -269,7 +269,6 @@ def __init__(
269269
self.update_stats_functions = step_method._make_progressbar_update_functions()
270270

271271
self._show_progress = show_progress
272-
self.divergences = 0
273272
self.completed_draws = 0
274273
self.total_draws = draws + tune
275274
self.desc = "Sampling chain"
@@ -341,17 +340,13 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
341340
elapsed = self._progress.tasks[chain_idx].elapsed
342341
speed, unit = self.compute_draw_speed(elapsed, draw)
343342

344-
if not tuning and stats and stats[0].get("diverging"):
345-
self.divergences += 1
346-
347343
if self.full_stats:
348344
failing = False
349345
all_step_stats = {}
350346

351-
# TODO: Index by chain already?
352347
chain_progress_stats = [
353-
update_states_fn(step_stats)
354-
for update_states_fn, step_stats in zip(
348+
update_stats_fn(step_stats)
349+
for update_stats_fn, step_stats in zip(
355350
self.update_stats_functions, stats, strict=True
356351
)
357352
]
@@ -405,9 +400,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
405400
]
406401

407402
return CustomProgress(
408-
DivergenceBarColumn(
403+
RecolorOnFailureBarColumn(
409404
table_column=Column("Progress", ratio=2),
410-
diverging_color="tab:red",
405+
failure_color="tab:red",
411406
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
412407
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
413408
),

0 commit comments

Comments
 (0)