diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 27d53c8687..46eb99b7c7 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -60,7 +60,7 @@ convert_observed_data, floatX, ) -from pymc.util import UNSET +from pymc.util import UNSET, safe_display from pymc.vartypes import continuous_types, string_types __all__ = [ @@ -544,6 +544,10 @@ def __new__( rv_out._repr_latex_ = types.MethodType( functools.partial(str_for_dist, formatting="latex"), rv_out ) + + # https://docs.marimo.io/guides/integrating_with_marimo/displaying_objects/#option-1-implement-a-_display_-method + rv_out._display_ = types.MethodType(functools.partial(safe_display, model=model), rv_out) + return rv_out @classmethod diff --git a/pymc/model/core.py b/pymc/model/core.py index 69e1fbed72..2bfd5889bd 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -72,6 +72,7 @@ get_transformed_name, get_value_vars_from_user_vars, get_var_name, + safe_display, treedict, treelist, ) @@ -2280,6 +2281,7 @@ def Deterministic(name, var, model=None, dims=None): ), var, ) + var._display_ = types.MethodType(functools.partial(safe_display, model=model), var) return var @@ -2408,5 +2410,6 @@ def normal_logp(value, mu, sigma): ), var, ) + var._display_ = types.MethodType(functools.partial(safe_display, model=model), var) return var diff --git a/pymc/util.py b/pymc/util.py index 32d8d65e70..3f92cad262 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -570,3 +570,11 @@ def get_random_generator( return random_generator_from_state(get_state_from_generator(seed)) seed = deepcopy(seed) return np.random.default_rng(seed) + + +def safe_display(variable, model): + import marimo as mo + + from pymc.model_graph import model_to_mermaid + + return mo.mermaid(model_to_mermaid(model=model, var_names=[variable.name]))