@@ -702,30 +702,164 @@ def callbacks(self, task: "Task"):
702702 self .finished_style = self .non_diverging_finished_style
703703
704704
705- def create_progress_bar (step_columns , init_stat_dict , progressbar , progressbar_theme ):
706- columns = [TextColumn ("{task.fields[draws]}" , table_column = Column ("Draws" , ratio = 1 ))]
707- columns += step_columns
708- columns += [
709- TextColumn (
710- "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
711- table_column = Column ("Sampling Speed" , ratio = 1 ),
712- ),
713- TimeElapsedColumn (table_column = Column ("Elapsed" , ratio = 1 )),
714- TimeRemainingColumn (table_column = Column ("Remaining" , ratio = 1 )),
715- ]
716-
717- return CustomProgress (
718- DivergenceBarColumn (
719- table_column = Column ("Progress" , ratio = 2 ),
720- diverging_color = "tab:red" ,
721- complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
722- finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
723- ),
724- * columns ,
725- console = Console (theme = progressbar_theme ),
726- disable = not progressbar ,
727- include_headers = True ,
728- )
705+ class ProgressManager :
706+ def __init__ (self , step_method , chains , draws , tune , progressbar , progressbar_theme ):
707+ mode = "chain"
708+ stats = "full"
709+
710+ if isinstance (progressbar , bool ):
711+ show_progress = progressbar
712+ else :
713+ show_progress = True
714+
715+ if "+" in progressbar :
716+ mode , stats = progressbar .split ("+" )
717+ else :
718+ mode = progressbar
719+ stats = "full"
720+
721+ if mode not in ["chain" , "combined" ]:
722+ raise ValueError ('Invalid mode. Valid values are "chain" and "combined"' )
723+ if stats not in ["full" , "simple" ]:
724+ raise ValueError ('Invalid stats. Valid values are "full" and "simple"' )
725+
726+ progress_columns , progress_stats = step_method ._progressbar_config (chains )
727+ self .combined_progress = mode == "combined"
728+ self .full_stats = stats == "full"
729+
730+ self ._progress = self .create_progress_bar (
731+ progress_columns ,
732+ progressbar = progressbar ,
733+ progressbar_theme = progressbar_theme ,
734+ )
735+
736+ self .progress_stats = progress_stats
737+ self .update_stats = step_method ._make_update_stats_function ()
738+
739+ self ._show_progress = show_progress
740+ self .divergences = 0
741+ self .completed_draws = 0
742+ self .total_draws = draws + tune
743+ self .desc = "Sampling chain"
744+ self .chains = chains
745+
746+ self ._tasks : list [Task ] | None = None
747+
748+ def __enter__ (self ):
749+ self ._initialize_tasks ()
750+
751+ return self ._progress .__enter__ ()
752+
753+ def __exit__ (self , exc_type , exc_val , exc_tb ):
754+ return self ._progress .__exit__ (exc_type , exc_val , exc_tb )
755+
756+ def _initialize_tasks (self ):
757+ if self .combined_progress :
758+ self .tasks = [
759+ self ._progress .add_task (
760+ self .desc .format (self ),
761+ completed = 0 ,
762+ draws = 0 ,
763+ total = self .total_draws * self .chains - 1 ,
764+ chain_idx = 0 ,
765+ sampling_speed = 0 ,
766+ speed_unit = "draws/s" ,
767+ ** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
768+ )
769+ ]
770+
771+ else :
772+ self .tasks = [
773+ self ._progress .add_task (
774+ self .desc .format (self ),
775+ completed = 0 ,
776+ draws = 0 ,
777+ total = self .total_draws - 1 ,
778+ chain_idx = chain_idx ,
779+ sampling_speed = 0 ,
780+ speed_unit = "draws/s" ,
781+ ** {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()},
782+ )
783+ for chain_idx in range (self .chains )
784+ ]
785+
786+ def compute_draw_speed (self , chain_idx , draws ):
787+ elapsed = self ._progress .tasks [chain_idx ].elapsed
788+ speed = draws / max (elapsed , 1e-6 )
789+
790+ if speed > 1 or speed == 0 :
791+ unit = "draws/s"
792+ else :
793+ unit = "s/draws"
794+ speed = 1 / speed
795+
796+ return speed , unit
797+
798+ def update (self , chain_idx , is_last , draw , tuning , stats ):
799+ if not self ._show_progress :
800+ return
801+
802+ self .completed_draws += 1
803+ if self .combined_progress :
804+ draw = self .completed_draws
805+ chain_idx = 0
806+
807+ speed , unit = self .compute_draw_speed (chain_idx , draw )
808+
809+ if not tuning and stats and stats [0 ].get ("diverging" ):
810+ self .divergences += 1
811+
812+ self .progress_stats = self .update_stats (self .progress_stats , stats , chain_idx )
813+ more_updates = (
814+ {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()}
815+ if self .full_stats
816+ else {}
817+ )
818+
819+ self ._progress .update (
820+ self .tasks [chain_idx ],
821+ completed = draw ,
822+ draws = draw ,
823+ sampling_speed = speed ,
824+ speed_unit = unit ,
825+ ** more_updates ,
826+ )
827+
828+ if is_last :
829+ self ._progress .update (
830+ self .tasks [chain_idx ],
831+ draws = draw + 1 if not self .combined_progress else draw - 1 ,
832+ ** more_updates ,
833+ refresh = True ,
834+ )
835+
836+ def create_progress_bar (self , step_columns , progressbar , progressbar_theme ):
837+ columns = [TextColumn ("{task.fields[draws]}" , table_column = Column ("Draws" , ratio = 1 ))]
838+
839+ if self .full_stats :
840+ columns += step_columns
841+
842+ columns += [
843+ TextColumn (
844+ "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
845+ table_column = Column ("Sampling Speed" , ratio = 1 ),
846+ ),
847+ TimeElapsedColumn (table_column = Column ("Elapsed" , ratio = 1 )),
848+ TimeRemainingColumn (table_column = Column ("Remaining" , ratio = 1 )),
849+ ]
850+
851+ return CustomProgress (
852+ DivergenceBarColumn (
853+ table_column = Column ("Progress" , ratio = 2 ),
854+ diverging_color = "tab:red" ,
855+ complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
856+ finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
857+ ),
858+ * columns ,
859+ console = Console (theme = progressbar_theme ),
860+ disable = not progressbar ,
861+ include_headers = True ,
862+ )
729863
730864
731865def compute_draw_speed (elapsed , draws ):
0 commit comments