Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
from pymc.model_graph import model_to_graphviz, model_to_mermaid
from pymc.pytensorf import (
PointFunc,
SeedSequenceSeed,
Expand Down Expand Up @@ -440,6 +439,8 @@
def _display_(self):
import marimo as mo

from pymc.model_graph import model_to_mermaid

Check warning on line 442 in pymc/model/core.py

View check run for this annotation

Codecov / codecov/patch

pymc/model/core.py#L442

Added line #L442 was not covered by tests

return mo.mermaid(model_to_mermaid(self))

@staticmethod
Expand Down Expand Up @@ -2002,6 +2003,8 @@
# creates the file `schools.pdf`
schools.to_graphviz().render("schools")
"""
from pymc.model_graph import model_to_graphviz

return model_to_graphviz(
model=self,
var_names=var_names,
Expand Down
106 changes: 42 additions & 64 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,11 @@
from typing import Any, cast

from pytensor import function
from pytensor.graph import Apply
from pytensor.graph.basic import ancestors, walk
from pytensor.scalar.basic import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.shape import Shape
from pytensor.tensor.variable import TensorVariable

import pymc as pm

from pymc.model.core import modelcontext
from pymc.util import VarName, get_default_varnames, get_var_name

__all__ = (
Expand Down Expand Up @@ -241,42 +236,32 @@ class ModelGraph:
def __init__(self, model):
self.model = model
self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
self._all_vars = {model[var_name] for var_name in self._all_var_names}
self.var_list = self.model.named_vars.values()

def get_parent_names(self, var: TensorVariable) -> set[VarName]:
if var.owner is None or var.owner.inputs is None:
if var.owner is None:
return set()

def _filter_non_parameter_inputs(var):
node = var.owner
if isinstance(node.op, Shape):
# Don't show shape-related dependencies
return []
if isinstance(node.op, RandomVariable):
# Filter out rng and size parameters or RandomVariable nodes
return node.op.dist_params(node)
else:
# Otherwise return all inputs
return node.inputs

blockers = set(self.model.named_vars)
named_vars = self._all_vars

def _expand(x):
nonlocal blockers
if x.name in blockers:
if x in named_vars:
# Don't go beyond named_vars
return [x]
if isinstance(x.owner, Apply):
return reversed(_filter_non_parameter_inputs(x))
return []

parents = set()
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand):
# Only consider nodes that are in the named model variables.
vname = getattr(x, "name", None)
if isinstance(vname, str) and vname in self._all_var_names:
parents.add(VarName(vname))
if x.owner is None:
return []
if isinstance(x.owner.op, Shape):
# Don't propagate shape-related dependencies
return []
# Continue walking the graph through the inputs
return x.owner.inputs

return parents
return {
cast(VarName, ancestor.name) # type: ignore[union-attr]
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
if ancestor in named_vars
}

def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
if var_names is None:
Expand Down Expand Up @@ -312,35 +297,28 @@ def make_compute_graph(
self, var_names: Iterable[VarName] | None = None
) -> dict[VarName, set[VarName]]:
"""Get map of var_name -> set(input var names) for the model."""
model = self.model
named_vars = self._all_vars
input_map: dict[VarName, set[VarName]] = defaultdict(set)

for var_name in self.vars_to_plot(var_names):
var = self.model[var_name]
parent_name = self.get_parent_names(var)
input_map[var_name] = input_map[var_name].union(parent_name)

if var in self.model.observed_RVs:
obs_node = self.model.rvs_to_values[var]

# loop created so that the elif block can go through this again
# and remove any intermediate ops, notably dtype casting, to observations
while True:
obs_name = obs_node.name
if obs_name and obs_name != var_name:
input_map[var_name] = input_map[var_name].difference({obs_name})
input_map[obs_name] = input_map[obs_name].union({var_name})
break
elif (
# for cases where observations are cast to a certain dtype
# see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
obs_node.owner
and isinstance(obs_node.owner.op, Elemwise)
and isinstance(obs_node.owner.op.scalar_op, Cast)
):
# we can retrieve the observation node by going up the graph
obs_node = obs_node.owner.inputs[0]
else:
break
var_names_to_plot = self.vars_to_plot(var_names)
for var_name in var_names_to_plot:
parent_names = self.get_parent_names(model[var_name])
input_map[var_name].update(parent_names)

for var_name in var_names_to_plot:
if (var := model[var_name]) in model.observed_RVs:
# Make observed `Data` variables flow from the observed RV, and not the other way around
# (In the generative graph they usually inform shape of the observed RV)
# We have to iterate over the ancestors of the observed values because there can be
# deterministic operations in between the `Data` variable and the observed value.
obs_var = model.rvs_to_values[var]
for ancestor in ancestors([obs_var]):
if ancestor not in named_vars:
continue
obs_name = cast(VarName, ancestor.name)
input_map[var_name].discard(obs_name)
input_map[obs_name].add(var_name)

return input_map

Expand All @@ -361,7 +339,7 @@ def get_plates(
plates = defaultdict(set)

# TODO: Evaluate all RV shapes at once
# This should help find discrepencies, and
# This should help find discrepancies, and
# avoids unnecessary function compiles for determining labels.
dim_lengths: dict[str, int] = {
dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items()
Expand Down Expand Up @@ -662,7 +640,7 @@ def model_to_networkx(
stacklevel=2,
)

model = pm.modelcontext(model)
model = modelcontext(model)
graph = ModelGraph(model)
return make_networkx(
name=model.name,
Expand Down Expand Up @@ -777,7 +755,7 @@ def model_to_graphviz(
stacklevel=2,
)

model = pm.modelcontext(model)
model = modelcontext(model)
graph = ModelGraph(model)
return make_graph(
model.name,
Expand Down Expand Up @@ -910,7 +888,7 @@ def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool =


"""
model = pm.modelcontext(model)
model = modelcontext(model)
graph = ModelGraph(model)
plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
edges = sorted(graph.edges(var_names=var_names))
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
for section, sdf in df.reset_index().groupby(args.groupby):
print(f"\n\n[{section}]")
for row in sdf.itertuples():
print(f"{row.file}:{row.line}: {row.type}: {row.message}")
print(f"{row.file}:{row.line}: {row.type} [{row.errorcode}]: {row.message}")
print()
else:
print(
Expand Down
2 changes: 1 addition & 1 deletion tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ def school_model(J: int) -> pm.Model:
)
def test_graphviz_call_function(self, var_names, filenames) -> None:
model = self.school_model(J=8)
with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz:
with patch("pymc.model_graph.model_to_graphviz") as mock_model_to_graphviz:
model.to_graphviz(var_names=var_names, save=filenames)
mock_model_to_graphviz.assert_called_once_with(
model=model,
Expand Down
29 changes: 27 additions & 2 deletions tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,14 +516,39 @@ def test_model_graph_with_intermediate_named_variables():
with pm.Model() as m1:
a = pm.Normal("a", 0, 1, shape=3)
pm.Normal("b", a.mean(axis=-1), 1)
assert dict(ModelGraph(m1).make_compute_graph()) == {"a": set(), "b": {"a"}}
assert ModelGraph(m1).make_compute_graph() == {"a": set(), "b": {"a"}}

with pm.Model() as m2:
a = pm.Normal("a", 0, 1)
b = a + 1
b.name = "b"
pm.Normal("c", b, 1)
assert dict(ModelGraph(m2).make_compute_graph()) == {"a": set(), "c": {"a"}}
assert ModelGraph(m2).make_compute_graph() == {"a": set(), "c": {"a"}}

# Regression test for https://github.com/pymc-devs/pymc/issues/7397
with pm.Model() as m3:
data = pt.as_tensor_variable(
np.ones((5, 3)),
name="C",
)
# C has the same name as `data` variable
# This used to be wrongly picked up as a dependency
C = pm.Deterministic("C", data)
# D depends on a variable called `C` but this is not really one in the model
D = pm.Deterministic("D", data)
# This actually depends on the model variable `C`
E = pm.Deterministic("E", C)
assert ModelGraph(m3).make_compute_graph() == {"C": set(), "D": set(), "E": {"C"}}


def test_model_graph_complex_observed_dependency():
with pm.Model() as model:
x = pm.Data("x", [0])
y = pm.Data("y", [0])
observed = pt.exp(x) + pt.log(y)
pm.Normal("obs", mu=0, observed=observed)

assert ModelGraph(model).make_compute_graph() == {"obs": set(), "x": {"obs"}, "y": {"obs"}}


@pytest.fixture
Expand Down