Skip to content

Commit f6a502a

Browse files
committed
ModelGraph only stopsancestry check at model defined named_vars
1 parent a7f361b commit f6a502a

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

pymc/model_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,11 @@ def _filter_non_parameter_inputs(var):
6262
# Otherwise return all inputs
6363
return node.inputs
6464

65+
blockers = set(self.model.named_vars)
66+
6567
def _expand(x):
66-
if x.name:
68+
nonlocal blockers
69+
if x.name in blockers:
6770
return [x]
6871
if isinstance(x.owner, Apply):
6972
return reversed(_filter_non_parameter_inputs(x))

pymc/tests/test_model_graph.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,18 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
407407

408408
class TestModelNonRandomVariableRVs(BaseModelGraphTest):
409409
model_func = model_non_random_variable_rvs
410+
411+
412+
def test_model_graph_with_intermediate_named_variables():
413+
# Issue 6421
414+
with pm.Model() as m1:
415+
a = pm.Normal("a", 0, 1, shape=3)
416+
pm.Normal("b", a.mean(axis=-1), 1)
417+
assert dict(ModelGraph(m1).make_compute_graph()) == {"a": set(), "b": {"a"}}
418+
419+
with pm.Model() as m2:
420+
a = pm.Normal("a", 0, 1)
421+
b = a + 1
422+
b.name = "b"
423+
pm.Normal("c", b, 1)
424+
assert dict(ModelGraph(m2).make_compute_graph()) == {"a": set(), "c": {"a"}}

0 commit comments

Comments
 (0)