@@ -267,6 +267,58 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
267
267
"""
268
268
269
269
270
+ def in_marimo_notebook () -> bool :
271
+ try :
272
+ import marimo as mo
273
+
274
+ return mo .running_in_notebook ()
275
+ except ImportError :
276
+ return False
277
+
278
+
279
+ def _mo_write_internal (cell_id , stream , value : object ) -> None :
280
+ """Write to marimo cell given cell_id and stream."""
281
+ from marimo ._output import formatting
282
+ from marimo ._messaging .ops import CellOp
283
+ from marimo ._messaging .tracebacks import write_traceback
284
+ from marimo ._messaging .cell_output import CellChannel
285
+
286
+ output = formatting .try_format (value )
287
+ if output .traceback is not None :
288
+ write_traceback (output .traceback )
289
+ CellOp .broadcast_output (
290
+ channel = CellChannel .OUTPUT ,
291
+ mimetype = output .mimetype ,
292
+ data = output .data ,
293
+ cell_id = cell_id ,
294
+ status = None ,
295
+ stream = stream ,
296
+ )
297
+
298
+
299
+ def _mo_create_replace ():
300
+ """Create mo.output.replace with current context pinned."""
301
+ from marimo ._runtime .context import get_context
302
+ from marimo ._runtime .context .types import ContextNotInitializedError
303
+ from marimo ._output import formatting
304
+
305
+ try :
306
+ ctx = get_context ()
307
+ except ContextNotInitializedError :
308
+ return
309
+
310
+ cell_id = ctx .execution_context .cell_id
311
+ execution_context = ctx .execution_context
312
+ stream = ctx .stream
313
+
314
+ def replace (value ):
315
+ execution_context .output = [formatting .as_html (value )]
316
+
317
+ _mo_write_internal (cell_id = cell_id , value = value , stream = stream )
318
+
319
+ return replace
320
+
321
+
270
322
# Adapted from fastprogress
271
323
def in_notebook ():
272
324
def in_colab ():
@@ -362,6 +414,28 @@ def callback(formatted):
362
414
self ._html = formatted
363
415
self .display_id .update (self )
364
416
417
+ progress_type = _lib .ProgressType .template_callback (
418
+ progress_rate , progress_template , cores , callback
419
+ )
420
+ elif in_marimo_notebook ():
421
+ import marimo as mo
422
+
423
+ if progress_template is None :
424
+ progress_template = _progress_template
425
+
426
+ if progress_style is None :
427
+ progress_style = _progress_style
428
+
429
+ self ._html = ""
430
+
431
+ mo .output .clear ()
432
+ mo_output_replace = _mo_create_replace ()
433
+
434
+ def callback (formatted ):
435
+ self ._html = formatted
436
+ html = mo .Html (f"{ progress_style } \n { formatted } " )
437
+ mo_output_replace (html )
438
+
365
439
progress_type = _lib .ProgressType .template_callback (
366
440
progress_rate , progress_template , cores , callback
367
441
)
0 commit comments