@@ -153,13 +153,13 @@ def call_column(column, task):
153153 return table
154154
155155
156- class DivergenceBarColumn (BarColumn ):
157- """Rich colorbar that changes color when a chain has detected a divergence ."""
156+ class RecolorOnFailureBarColumn (BarColumn ):
157+ """Rich colorbar that changes color when a chain has detected a failure ."""
158158
159- def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
159+ def __init__ (self , * args , failing_color = "red" , ** kwargs ):
160160 from matplotlib .colors import to_rgb
161161
162- self .failing_color = diverging_color
162+ self .failing_color = failing_color
163163 self .failing_rgb = [int (x * 255 ) for x in to_rgb (self .failing_color )]
164164
165165 super ().__init__ (* args , ** kwargs )
@@ -269,7 +269,6 @@ def __init__(
269269 self .update_stats_functions = step_method ._make_progressbar_update_functions ()
270270
271271 self ._show_progress = show_progress
272- self .divergences = 0
273272 self .completed_draws = 0
274273 self .total_draws = draws + tune
275274 self .desc = "Sampling chain"
@@ -341,17 +340,13 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
341340 elapsed = self ._progress .tasks [chain_idx ].elapsed
342341 speed , unit = self .compute_draw_speed (elapsed , draw )
343342
344- if not tuning and stats and stats [0 ].get ("diverging" ):
345- self .divergences += 1
346-
347343 if self .full_stats :
348344 failing = False
349345 all_step_stats = {}
350346
351- # TODO: Index by chain already?
352347 chain_progress_stats = [
353- update_states_fn (step_stats )
354- for update_states_fn , step_stats in zip (
348+ update_stats_fn (step_stats )
349+ for update_stats_fn , step_stats in zip (
355350 self .update_stats_functions , stats , strict = True
356351 )
357352 ]
@@ -405,9 +400,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
405400 ]
406401
407402 return CustomProgress (
408- DivergenceBarColumn (
403+ RecolorOnFailureBarColumn (
409404 table_column = Column ("Progress" , ratio = 2 ),
410- diverging_color = "tab:red" ,
405+ failure_color = "tab:red" ,
411406 complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
412407 finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
413408 ),
0 commit comments