Skip to content

Commit facb54d

Browse files
committed
Break circular dependency between model_graph.py and mode/core.py with specific lazy imports
1 parent f4bdc6c commit facb54d

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

pymc/model/core.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
from pymc.logprob.basic import transformed_conditional_logp
5353
from pymc.logprob.transforms import Transform
5454
from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
55-
from pymc.model_graph import model_to_graphviz
5655
from pymc.pytensorf import (
5756
PointFunc,
5857
SeedSequenceSeed,
@@ -437,6 +436,13 @@ def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
437436
"""Exit the context manager."""
438437
_ = MODEL_MANAGER.active_contexts.pop()
439438

439+
def _display_(self):
440+
import marimo as mo
441+
442+
from pymc.model_graph import model_to_mermaid
443+
444+
return mo.mermaid(model_to_mermaid(self))
445+
440446
@staticmethod
441447
def _validate_name(name):
442448
if name.endswith(":"):
@@ -1997,6 +2003,8 @@ def to_graphviz(
19972003
# creates the file `schools.pdf`
19982004
schools.to_graphviz().render("schools")
19992005
"""
2006+
from pymc.model_graph import model_to_graphviz
2007+
20002008
return model_to_graphviz(
20012009
model=self,
20022010
var_names=var_names,

pymc/model_graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
from pytensor.tensor.shape import Shape
3030
from pytensor.tensor.variable import TensorVariable
3131

32-
import pymc as pm
33-
32+
from pymc.model.core import modelcontext
3433
from pymc.util import VarName, get_default_varnames, get_var_name
3534

3635
__all__ = (
@@ -662,7 +661,7 @@ def model_to_networkx(
662661
stacklevel=2,
663662
)
664663

665-
model = pm.modelcontext(model)
664+
model = modelcontext(model)
666665
graph = ModelGraph(model)
667666
return make_networkx(
668667
name=model.name,
@@ -777,7 +776,7 @@ def model_to_graphviz(
777776
stacklevel=2,
778777
)
779778

780-
model = pm.modelcontext(model)
779+
model = modelcontext(model)
781780
graph = ModelGraph(model)
782781
return make_graph(
783782
model.name,
@@ -904,7 +903,7 @@ def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool =
904903
905904
906905
"""
907-
model = pm.modelcontext(model)
906+
model = modelcontext(model)
908907
graph = ModelGraph(model)
909908
plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
910909
edges = sorted(graph.edges(var_names=var_names))

0 commit comments

Comments
 (0)