@@ -657,22 +657,22 @@ class DivergenceBarColumn(BarColumn):
657657 def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
658658 from matplotlib .colors import to_rgb
659659
660- self .diverging_color = diverging_color
661- self .diverging_rgb = [int (x * 255 ) for x in to_rgb (self .diverging_color )]
660+ self .failing_color = diverging_color
661+ self .failing_rgb = [int (x * 255 ) for x in to_rgb (self .failing_color )]
662662
663663 super ().__init__ (* args , ** kwargs )
664664
665- self .non_diverging_style = self .complete_style
666- self .non_diverging_finished_style = self .finished_style
665+ self .default_complete_style = self .complete_style
666+ self .default_finished_style = self .finished_style
667667
668668 def callbacks (self , task : "Task" ):
669- divergences = task .fields .get ("divergences" , 0 )
670- if isinstance (divergences , float | int ) and divergences > 0 :
671- self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
672- self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
669+ if task .fields ["failing" ]:
670+ self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
671+ self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
673672 else :
674- self .complete_style = self .non_diverging_style
675- self .finished_style = self .non_diverging_finished_style
673+ # Recovered from failing yay
674+ self .complete_style = self .default_complete_style
675+ self .finished_style = self .default_finished_style
676676
677677
678678class ProgressBarManager :
@@ -794,6 +794,7 @@ def _initialize_tasks(self):
794794 chain_idx = 0 ,
795795 sampling_speed = 0 ,
796796 speed_unit = "draws/s" ,
797+ failing = False ,
797798 ** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
798799 )
799800 ]
@@ -808,6 +809,7 @@ def _initialize_tasks(self):
808809 chain_idx = chain_idx ,
809810 sampling_speed = 0 ,
810811 speed_unit = "draws/s" ,
812+ failing = False ,
811813 ** {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()},
812814 )
813815 for chain_idx in range (self .chains )
@@ -829,16 +831,22 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
829831 self .divergences += 1
830832
831833 if self .full_stats :
834+ failing = False
835+ all_step_stats = {}
836+
832837 # TODO: Index by chain already?
833838 chain_progress_stats = [
834839 update_states_fn (step_stats )
835840 for update_states_fn , step_stats in zip (
836841 self .update_stats_functions , stats , strict = True
837842 )
838843 ]
839- all_step_stats = {}
840844 for step_stats in chain_progress_stats :
841845 for key , val in step_stats .items ():
846+ if key == "failing" :
847+ failing |= val
848+ continue
849+
842850 if key in all_step_stats :
843851 continue
844852 count = (
@@ -849,6 +857,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
849857 all_step_stats [key ] = val
850858
851859 else :
860+ failing = False
852861 all_step_stats = {}
853862
854863 # more_updates = (
@@ -863,6 +872,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
863872 draws = draw ,
864873 sampling_speed = speed ,
865874 speed_unit = unit ,
875+ failing = failing ,
866876 ** all_step_stats ,
867877 )
868878
0 commit comments