@@ -168,28 +168,28 @@ def call_column(column, task):
168168 return table
169169
170170
171- class DivergenceBarColumn (BarColumn ):
172- """Rich colorbar that changes color when a chain has detected a divergence ."""
171+ class RecolorOnFailureBarColumn (BarColumn ):
172+ """Rich colorbar that changes color when a chain has detected a failure ."""
173173
174- def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
174+ def __init__ (self , * args , failing_color = "red" , ** kwargs ):
175175 from matplotlib .colors import to_rgb
176176
177- self .diverging_color = diverging_color
178- self .diverging_rgb = [int (x * 255 ) for x in to_rgb (self .diverging_color )]
177+ self .failing_color = failing_color
178+ self .failing_rgb = [int (x * 255 ) for x in to_rgb (self .failing_color )]
179179
180180 super ().__init__ (* args , ** kwargs )
181181
182- self .non_diverging_style = self .complete_style
183- self .non_diverging_finished_style = self .finished_style
182+ self .default_complete_style = self .complete_style
183+ self .default_finished_style = self .finished_style
184184
185185 def callbacks (self , task : "Task" ):
186- divergences = task .fields .get ("divergences" , 0 )
187- if isinstance (divergences , float | int ) and divergences > 0 :
188- self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
189- self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
186+ if task .fields ["failing" ]:
187+ self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
188+ self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
190189 else :
191- self .complete_style = self .non_diverging_style
192- self .finished_style = self .non_diverging_finished_style
190+ # Recovered from failing yay
191+ self .complete_style = self .default_complete_style
192+ self .finished_style = self .default_finished_style
193193
194194
195195class ProgressBarManager :
@@ -284,7 +284,6 @@ def __init__(
284284 self .update_stats_functions = step_method ._make_progressbar_update_functions ()
285285
286286 self ._show_progress = show_progress
287- self .divergences = 0
288287 self .completed_draws = 0
289288 self .total_draws = draws + tune
290289 self .desc = "Sampling chain"
@@ -311,6 +310,7 @@ def _initialize_tasks(self):
311310 chain_idx = 0 ,
312311 sampling_speed = 0 ,
313312 speed_unit = "draws/s" ,
313+ failing = False ,
314314 ** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
315315 )
316316 ]
@@ -325,6 +325,7 @@ def _initialize_tasks(self):
325325 chain_idx = chain_idx ,
326326 sampling_speed = 0 ,
327327 speed_unit = "draws/s" ,
328+ failing = False ,
328329 ** {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()},
329330 )
330331 for chain_idx in range (self .chains )
@@ -354,42 +355,43 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
354355 elapsed = self ._progress .tasks [chain_idx ].elapsed
355356 speed , unit = self .compute_draw_speed (elapsed , draw )
356357
357- if not tuning and stats and stats [ 0 ]. get ( "diverging" ):
358- self . divergences += 1
358+ failing = False
359+ all_step_stats = {}
359360
360- if self .full_stats :
361- # TODO: Index by chain already?
362- chain_progress_stats = [
363- update_states_fn (step_stats )
364- for update_states_fn , step_stats in zip (
365- self .update_stats_functions , stats , strict = True
366- )
367- ]
368- all_step_stats = {}
369- for step_stats in chain_progress_stats :
370- for key , val in step_stats .items ():
371- if key in all_step_stats :
372- # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
373- continue
374- else :
375- all_step_stats [key ] = val
376-
377- else :
378- all_step_stats = {}
361+ chain_progress_stats = [
362+ update_stats_fn (step_stats )
363+ for update_stats_fn , step_stats in zip (self .update_stats_functions , stats , strict = True )
364+ ]
365+ for step_stats in chain_progress_stats :
366+ for key , val in step_stats .items ():
367+ if key == "failing" :
368+ failing |= val
369+ continue
370+ if not self .full_stats :
371+ # Only care about the "failing" flag
372+ continue
373+
374+ if key in all_step_stats :
375+ # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
376+ continue
377+ else :
378+ all_step_stats [key ] = val
379379
380380 self ._progress .update (
381381 self .tasks [chain_idx ],
382382 completed = draw ,
383383 draws = draw ,
384384 sampling_speed = speed ,
385385 speed_unit = unit ,
386+ failing = failing ,
386387 ** all_step_stats ,
387388 )
388389
389390 if is_last :
390391 self ._progress .update (
391392 self .tasks [chain_idx ],
392393 draws = draw + 1 if not self .combined_progress else draw ,
394+ failing = failing ,
393395 ** all_step_stats ,
394396 refresh = True ,
395397 )
@@ -410,9 +412,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
410412 ]
411413
412414 return CustomProgress (
413- DivergenceBarColumn (
415+ RecolorOnFailureBarColumn (
414416 table_column = Column ("Progress" , ratio = 2 ),
415- diverging_color = "tab:red" ,
417+ failing_color = "tab:red" ,
416418 complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
417419 finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
418420 ),
0 commit comments