12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
from collections .abc import Iterable
15
- from typing import TYPE_CHECKING , Literal
15
+ from typing import TYPE_CHECKING , Literal , Protocol , cast
16
16
17
17
from rich .box import SIMPLE_HEAD
18
18
from rich .console import Console
@@ -192,6 +192,132 @@ def callbacks(self, task: "Task"):
192
192
self .finished_style = self .default_finished_style
193
193
194
194
195
+ class ProgressBar (Protocol ):
196
+ @property
197
+ def tasks (self ):
198
+ """Get the tasks in the progress bar."""
199
+
200
+ def add_task (self , * args , ** kwargs ):
201
+ """Add a task to the progress bar."""
202
+
203
+ def update (self , task_id , ** kwargs ):
204
+ """Update the task with the given ID with the provided keyword arguments."""
205
+
206
+ def __enter__ (self ):
207
+ """Enter the context manager."""
208
+
209
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
210
+ """Exit the context manager."""
211
+
212
+
213
+ def compute_draw_speed (elapsed , draws ):
214
+ speed = draws / max (elapsed , 1e-6 )
215
+
216
+ if speed > 1 or speed == 0 :
217
+ unit = "draws/s"
218
+ else :
219
+ unit = "s/draws"
220
+ speed = 1 / speed
221
+
222
+ return speed , unit
223
+
224
+
225
+ def create_rich_progress_bar (full_stats , step_columns , progressbar , progressbar_theme ):
226
+ columns = [TextColumn ("{task.fields[draws]}" , table_column = Column ("Draws" , ratio = 1 ))]
227
+
228
+ if full_stats :
229
+ columns += step_columns
230
+
231
+ columns += [
232
+ TextColumn (
233
+ "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
234
+ table_column = Column ("Sampling Speed" , ratio = 1 ),
235
+ ),
236
+ TimeElapsedColumn (table_column = Column ("Elapsed" , ratio = 1 )),
237
+ TimeRemainingColumn (table_column = Column ("Remaining" , ratio = 1 )),
238
+ ]
239
+
240
+ return CustomProgress (
241
+ RecolorOnFailureBarColumn (
242
+ table_column = Column ("Progress" , ratio = 2 ),
243
+ failing_color = "tab:red" ,
244
+ complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
245
+ finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
246
+ ),
247
+ * columns ,
248
+ console = Console (theme = progressbar_theme ),
249
+ disable = not progressbar ,
250
+ include_headers = True ,
251
+ )
252
+
253
+
254
+ class MarimoProgressTask :
255
+ def __init__ (self , * args , ** kwargs ):
256
+ self .args = args
257
+ self .kwargs = kwargs
258
+
259
+ @property
260
+ def chain_idx (self ) -> int :
261
+ return self .kwargs .get ("chain_idx" , 0 )
262
+
263
+ @property
264
+ def total (self ):
265
+ return self .kwargs .get ("total" , 0 )
266
+
267
+ @property
268
+ def elapsed (self ):
269
+ return self .kwargs .get ("elapsed" , 0 )
270
+
271
+
272
+ class MarimoProgressBar :
273
+ def __init__ (self ) -> None :
274
+ self .tasks = []
275
+ self .divergences = {}
276
+
277
+ def __enter__ (self ):
278
+ """Enter the context manager."""
279
+ import marimo as mo
280
+
281
+ total_draws = (self .tasks [0 ].total + 1 ) * len (self .tasks )
282
+
283
+ self .bar = mo .status .progress_bar (total = total_draws , title = "Sampling PyMC model" )
284
+
285
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
286
+ """Exit the context manager."""
287
+ self .bar ._finish ()
288
+
289
+ def add_task (self , * args , ** kwargs ):
290
+ """Add a task to the progress bar."""
291
+ task = MarimoProgressTask (* args , ** kwargs )
292
+ self .tasks .append (task )
293
+ return task
294
+
295
+ def update (self , task_id , ** kwargs ):
296
+ """Update the task with the given ID with the provided keyword arguments."""
297
+ if self .bar .progress .current >= cast (int , self .bar .progress .total ):
298
+ return
299
+
300
+ self .divergences [task_id .chain_idx ] = kwargs .get ("divergences" , 0 )
301
+
302
+ total_divergences = sum (self .divergences .values ())
303
+
304
+ update_kwargs = {}
305
+ if total_divergences > 0 :
306
+ word = "draws" if total_divergences > 1 else "draw"
307
+ update_kwargs ["subtitle" ] = f"{ total_divergences } diverging { word } "
308
+
309
+ self .bar .progress .update (** update_kwargs )
310
+
311
+
312
+ def in_marimo_notebook () -> bool :
313
+ try :
314
+ import marimo as mo
315
+
316
+ return mo .running_in_notebook ()
317
+ except ImportError :
318
+ return False
319
+
320
+
195
321
class ProgressBarManager :
196
322
"""Manage progress bars displayed during sampling."""
197
323
@@ -203,6 +329,7 @@ def __init__(
203
329
tune : int ,
204
330
progressbar : bool | ProgressBarType = True ,
205
331
progressbar_theme : Theme | None = None ,
332
+ progress : ProgressBar | None = None ,
206
333
):
207
334
"""
208
335
Manage progress bars displayed during sampling.
@@ -275,11 +402,16 @@ def __init__(
275
402
276
403
progress_columns , progress_stats = step_method ._progressbar_config (chains )
277
404
278
- self ._progress = self .create_progress_bar (
279
- progress_columns ,
280
- progressbar = progressbar ,
281
- progressbar_theme = progressbar_theme ,
282
- )
405
+ if in_marimo_notebook ():
406
+ self .combined_progress = False
407
+ self ._progress = MarimoProgressBar ()
408
+ else :
409
+ self ._progress = progress or create_rich_progress_bar (
410
+ self .full_stats ,
411
+ progress_columns ,
412
+ progressbar = progressbar ,
413
+ progressbar_theme = progressbar_theme ,
414
+ )
283
415
self .progress_stats = progress_stats
284
416
self .update_stats_functions = step_method ._make_progressbar_update_functions ()
285
417
@@ -331,18 +463,6 @@ def _initialize_tasks(self):
331
463
for chain_idx in range (self .chains )
332
464
]
333
465
334
- @staticmethod
335
- def compute_draw_speed (elapsed , draws ):
336
- speed = draws / max (elapsed , 1e-6 )
337
-
338
- if speed > 1 or speed == 0 :
339
- unit = "draws/s"
340
- else :
341
- unit = "s/draws"
342
- speed = 1 / speed
343
-
344
- return speed , unit
345
-
346
466
def update (self , chain_idx , is_last , draw , tuning , stats ):
347
467
if not self ._show_progress :
348
468
return
@@ -353,7 +473,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
353
473
chain_idx = 0
354
474
355
475
elapsed = self ._progress .tasks [chain_idx ].elapsed
356
- speed , unit = self . compute_draw_speed (elapsed , draw )
476
+ speed , unit = compute_draw_speed (elapsed , draw )
357
477
358
478
failing = False
359
479
all_step_stats = {}
@@ -395,31 +515,3 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
395
515
** all_step_stats ,
396
516
refresh = True ,
397
517
)
398
-
399
- def create_progress_bar (self , step_columns , progressbar , progressbar_theme ):
400
- columns = [TextColumn ("{task.fields[draws]}" , table_column = Column ("Draws" , ratio = 1 ))]
401
-
402
- if self .full_stats :
403
- columns += step_columns
404
-
405
- columns += [
406
- TextColumn (
407
- "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
408
- table_column = Column ("Sampling Speed" , ratio = 1 ),
409
- ),
410
- TimeElapsedColumn (table_column = Column ("Elapsed" , ratio = 1 )),
411
- TimeRemainingColumn (table_column = Column ("Remaining" , ratio = 1 )),
412
- ]
413
-
414
- return CustomProgress (
415
- RecolorOnFailureBarColumn (
416
- table_column = Column ("Progress" , ratio = 2 ),
417
- failing_color = "tab:red" ,
418
- complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
419
- finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
420
- ),
421
- * columns ,
422
- console = Console (theme = progressbar_theme ),
423
- disable = not progressbar ,
424
- include_headers = True ,
425
- )
0 commit comments