Skip to content

Commit 4984ef3

Browse files
committed
Do not set rng updates for RandomVariables that are not computed in the graph
1 parent 73c9e3c commit 4984ef3

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

pymc/aesaraf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
Apply,
4040
Constant,
4141
Variable,
42-
ancestors,
4342
clone_get_equiv,
4443
graph_inputs,
44+
vars_between,
4545
walk,
4646
)
4747
from aesara.graph.fg import FunctionGraph
@@ -975,8 +975,8 @@ def compile_pymc(
975975
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
976976
for rv in (
977977
node
978-
for node in ancestors(output_to_list)
979-
if node.owner and isinstance(node.owner.op, RandomVariable)
978+
for node in vars_between(inputs, output_to_list)
979+
if node.owner and isinstance(node.owner.op, RandomVariable) and node not in inputs
980980
):
981981
rng = rv.owner.inputs[0]
982982
if not hasattr(rng, "default_update"):

pymc/tests/test_aesaraf.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,29 @@ def test_compile_pymc_missing_default_explicit_updates():
636636
# And again, it should be overridden by an explicit update
637637
f = compile_pymc([], x, updates={rng: x.owner.outputs[0]})
638638
assert f() != f()
639+
640+
641+
def test_compile_pymc_updates_inputs():
642+
"""Test that compile_pymc does not include rngs updates of variables that are inputs
643+
or ancestors to inputs
644+
"""
645+
x = at.random.normal()
646+
y = at.random.normal(x)
647+
z = at.random.normal(y)
648+
649+
for inputs, rvs_in_graph in (
650+
([], 3),
651+
([x], 2),
652+
([y], 1),
653+
([z], 0),
654+
([x, y], 1),
655+
([x, y, z], 0),
656+
):
657+
fn = compile_pymc(inputs, z, on_unused_input="ignore")
658+
fn_fgraph = fn.maker.fgraph
659+
# Each RV adds a shared input for its rng
660+
assert len(fn_fgraph.inputs) == len(inputs) + rvs_in_graph
661+
# If the output is an input, the graph has a DeepCopyOp
662+
assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
663+
# Each RV adds a shared output for its rng
664+
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

0 commit comments

Comments
 (0)