2828import numpy as np
2929
3030from rich .console import Console
31- from rich .progress import BarColumn , TextColumn , TimeElapsedColumn , TimeRemainingColumn
31+ from rich .progress import TextColumn
32+ from rich .style import Style
33+ from rich .table import Column
3234from rich .theme import Theme
3335from threadpoolctl import threadpool_limits
3436
3739from pymc .exceptions import SamplingError
3840from pymc .util import (
3941 CustomProgress ,
42+ DivergenceBarColumn ,
4043 RandomGeneratorState ,
4144 default_progress_theme ,
4245 get_state_from_generator ,
@@ -487,20 +490,35 @@ def __init__(
487490 self ._in_context = False
488491
489492 self ._progress = CustomProgress (
490- "[progress.description]{task.description}" ,
491- BarColumn (),
492- "[progress.percentage]{task.percentage:>3.0f}%" ,
493- TimeRemainingColumn (),
494- TextColumn ("/" ),
495- TimeElapsedColumn (),
493+ DivergenceBarColumn (
494+ table_column = Column ("Progress" , ratio = 2 ),
495+ diverging_color = "tab:red" ,
496+ diverging_finished_color = "tab:purple" ,
497+ complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
498+ finished_style = Style .parse ("rgb(44,160,44)" ), # tab:green
499+ ),
500+ TextColumn ("{task.fields[draws]:,d}" , table_column = Column ("Draws" , ratio = 1 )),
501+ TextColumn (
502+ "{task.fields[divergences]:,d}" , table_column = Column ("Divergences" , ratio = 1 )
503+ ),
504+ TextColumn ("{task.fields[step_size]:0.2f}" , table_column = Column ("Step size" , ratio = 1 )),
505+ TextColumn ("{task.fields[tree_depth]:,d}" , table_column = Column ("Tree depth" , ratio = 1 )),
506+ TextColumn (
507+ "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
508+ table_column = Column ("Sampling Speed" , ratio = 1 ),
509+ ),
496510 console = Console (theme = progressbar_theme ),
497511 disable = not progressbar ,
512+ include_headers = True ,
498513 )
514+
499515 self ._show_progress = progressbar
500516 self ._divergences = 0
517+ self ._divergences_by_chain = [0 ] * chains
501518 self ._completed_draws = 0
502- self ._total_draws = chains * (draws + tune )
503- self ._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
519+ self ._completed_draws_by_chain = [0 ] * chains
520+ self ._total_draws = draws + tune
521+ self ._desc = "Sampling chain"
504522 self ._chains = chains
505523
506524 def _make_active (self ):
@@ -517,31 +535,71 @@ def __iter__(self):
517535 self ._make_active ()
518536
519537 with self ._progress as progress :
520- task = progress .add_task (
521- self ._desc .format (self ),
522- completed = self ._completed_draws ,
523- total = self ._total_draws ,
524- )
538+ tasks = [
539+ progress .add_task (
540+ self ._desc .format (self ),
541+ completed = self ._completed_draws ,
542+ total = self ._total_draws ,
543+ chain_idx = chain_idx ,
544+ draws = 0 ,
545+ divergences = 0 ,
546+ step_size = 0.0 ,
547+ tree_depth = 0 ,
548+ sampling_speed = 0 ,
549+ speed_unit = "draws/s" ,
550+ )
551+ for chain_idx in range (self ._chains )
552+ ]
525553
526554 while self ._active :
527555 draw = ProcessAdapter .recv_draw (self ._active )
528556 proc , is_last , draw , tuning , stats = draw
557+ speed = 0
558+ unit = "draws/s"
559+
529560 self ._completed_draws += 1
561+ self ._completed_draws_by_chain [proc .chain ] += 1
562+
530563 if not tuning and stats and stats [0 ].get ("diverging" ):
531564 self ._divergences += 1
565+ self ._divergences_by_chain [proc .chain ] += 1
566+
567+ if self ._show_progress :
568+ elapsed = progress ._tasks [proc .chain ].elapsed
569+ speed = self ._completed_draws_by_chain [proc .chain ] / elapsed
570+
571+ if speed > 1 :
572+ unit = "draws/s"
573+ else :
574+ unit = "s/draws"
575+ speed = 1 / speed
576+
532577 progress .update (
533- task ,
534- completed = self ._completed_draws ,
535- total = self ._total_draws ,
536- description = self ._desc .format (self ),
578+ tasks [proc .chain ],
579+ completed = self ._completed_draws_by_chain [proc .chain ],
580+ draws = draw ,
581+ divergences = self ._divergences_by_chain [proc .chain ],
582+ step_size = stats [0 ].get ("step_size" , 0 ),
583+ tree_depth = stats [0 ].get ("tree_size" , 0 ),
584+ sampling_speed = speed ,
585+ speed_unit = unit ,
537586 )
538587
539588 if is_last :
589+ self ._completed_draws_by_chain [proc .chain ] += 1
590+
540591 proc .join ()
541592 self ._active .remove (proc )
542593 self ._finished .append (proc )
543594 self ._make_active ()
544- progress .update (task , description = self ._desc .format (self ), refresh = True )
595+ progress .update (
596+ tasks [proc .chain ],
597+ draws = draw + 1 ,
598+ divergences = self ._divergences_by_chain [proc .chain ],
599+ step_size = stats [0 ].get ("step_size" , 0 ),
600+ tree_depth = stats [0 ].get ("tree_size" , 0 ),
601+ refresh = True ,
602+ )
545603
546604 # We could also yield proc.shared_point_view directly,
547605 # and only call proc.write_next() after the yield returns.
0 commit comments