Skip to content

Commit 0a5e0c5

Browse files
committed
Break circular dependency between model_graph.py and mode/core.py with specific lazy imports
1 parent 718c8e4 commit 0a5e0c5

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

pymc/model/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
from pymc.logprob.basic import transformed_conditional_logp
5353
from pymc.logprob.transforms import Transform
5454
from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
55-
from pymc.model_graph import model_to_graphviz, model_to_mermaid
5655
from pymc.pytensorf import (
5756
PointFunc,
5857
SeedSequenceSeed,
@@ -440,6 +439,8 @@ def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
440439
def _display_(self):
441440
import marimo as mo
442441

442+
from pymc.model_graph import model_to_mermaid
443+
443444
return mo.mermaid(model_to_mermaid(self))
444445

445446
@staticmethod
@@ -2002,6 +2003,8 @@ def to_graphviz(
20022003
# creates the file `schools.pdf`
20032004
schools.to_graphviz().render("schools")
20042005
"""
2006+
from pymc.model_graph import model_to_graphviz
2007+
20052008
return model_to_graphviz(
20062009
model=self,
20072010
var_names=var_names,

pymc/model_graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
from pytensor.tensor.shape import Shape
3030
from pytensor.tensor.variable import TensorVariable
3131

32-
import pymc as pm
33-
32+
from pymc.model.core import modelcontext
3433
from pymc.util import VarName, get_default_varnames, get_var_name
3534

3635
__all__ = (
@@ -662,7 +661,7 @@ def model_to_networkx(
662661
stacklevel=2,
663662
)
664663

665-
model = pm.modelcontext(model)
664+
model = modelcontext(model)
666665
graph = ModelGraph(model)
667666
return make_networkx(
668667
name=model.name,
@@ -777,7 +776,7 @@ def model_to_graphviz(
777776
stacklevel=2,
778777
)
779778

780-
model = pm.modelcontext(model)
779+
model = modelcontext(model)
781780
graph = ModelGraph(model)
782781
return make_graph(
783782
model.name,
@@ -910,7 +909,7 @@ def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool =
910909
911910
912911
"""
913-
model = pm.modelcontext(model)
912+
model = modelcontext(model)
914913
graph = ModelGraph(model)
915914
plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
916915
edges = sorted(graph.edges(var_names=var_names))

tests/model/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1750,7 +1750,7 @@ def school_model(J: int) -> pm.Model:
17501750
)
17511751
def test_graphviz_call_function(self, var_names, filenames) -> None:
17521752
model = self.school_model(J=8)
1753-
with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz:
1753+
with patch("pymc.model_graph.model_to_graphviz") as mock_model_to_graphviz:
17541754
model.to_graphviz(var_names=var_names, save=filenames)
17551755
mock_model_to_graphviz.assert_called_once_with(
17561756
model=model,

0 commit comments

Comments
 (0)