Skip to content

Commit 5ad9c14

Browse files
authored
feat: add support for marimo notebooks (#224)
1 parent 53be4e0 commit 5ad9c14

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

python/nutpie/sample.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,58 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
267267
"""
268268

269269

270+
def in_marimo_notebook() -> bool:
271+
try:
272+
import marimo as mo
273+
274+
return mo.running_in_notebook()
275+
except ImportError:
276+
return False
277+
278+
279+
def _mo_write_internal(cell_id, stream, value: object) -> None:
280+
"""Write to marimo cell given cell_id and stream."""
281+
from marimo._output import formatting
282+
from marimo._messaging.ops import CellOp
283+
from marimo._messaging.tracebacks import write_traceback
284+
from marimo._messaging.cell_output import CellChannel
285+
286+
output = formatting.try_format(value)
287+
if output.traceback is not None:
288+
write_traceback(output.traceback)
289+
CellOp.broadcast_output(
290+
channel=CellChannel.OUTPUT,
291+
mimetype=output.mimetype,
292+
data=output.data,
293+
cell_id=cell_id,
294+
status=None,
295+
stream=stream,
296+
)
297+
298+
299+
def _mo_create_replace():
300+
"""Create mo.output.replace with current context pinned."""
301+
from marimo._runtime.context import get_context
302+
from marimo._runtime.context.types import ContextNotInitializedError
303+
from marimo._output import formatting
304+
305+
try:
306+
ctx = get_context()
307+
except ContextNotInitializedError:
308+
return
309+
310+
cell_id = ctx.execution_context.cell_id
311+
execution_context = ctx.execution_context
312+
stream = ctx.stream
313+
314+
def replace(value):
315+
execution_context.output = [formatting.as_html(value)]
316+
317+
_mo_write_internal(cell_id=cell_id, value=value, stream=stream)
318+
319+
return replace
320+
321+
270322
# Adapted from fastprogress
271323
def in_notebook():
272324
def in_colab():
@@ -362,6 +414,28 @@ def callback(formatted):
362414
self._html = formatted
363415
self.display_id.update(self)
364416

417+
progress_type = _lib.ProgressType.template_callback(
418+
progress_rate, progress_template, cores, callback
419+
)
420+
elif in_marimo_notebook():
421+
import marimo as mo
422+
423+
if progress_template is None:
424+
progress_template = _progress_template
425+
426+
if progress_style is None:
427+
progress_style = _progress_style
428+
429+
self._html = ""
430+
431+
mo.output.clear()
432+
mo_output_replace = _mo_create_replace()
433+
434+
def callback(formatted):
435+
self._html = formatted
436+
html = mo.Html(f"{progress_style}\n{formatted}")
437+
mo_output_replace(html)
438+
365439
progress_type = _lib.ProgressType.template_callback(
366440
progress_rate, progress_template, cores, callback
367441
)

0 commit comments

Comments
 (0)