File tree Expand file tree Collapse file tree 2 files changed +29
-3
lines changed Expand file tree Collapse file tree 2 files changed +29
-3
lines changed Original file line number Diff line number Diff line change 39
39
Apply ,
40
40
Constant ,
41
41
Variable ,
42
- ancestors ,
43
42
clone_get_equiv ,
44
43
graph_inputs ,
44
+ vars_between ,
45
45
walk ,
46
46
)
47
47
from aesara .graph .fg import FunctionGraph
@@ -975,8 +975,8 @@ def compile_pymc(
975
975
output_to_list = outputs if isinstance (outputs , (list , tuple )) else [outputs ]
976
976
for rv in (
977
977
node
978
- for node in ancestors ( output_to_list )
979
- if node .owner and isinstance (node .owner .op , RandomVariable )
978
+ for node in vars_between ( inputs , output_to_list )
979
+ if node .owner and isinstance (node .owner .op , RandomVariable ) and node not in inputs
980
980
):
981
981
rng = rv .owner .inputs [0 ]
982
982
if not hasattr (rng , "default_update" ):
Original file line number Diff line number Diff line change @@ -636,3 +636,29 @@ def test_compile_pymc_missing_default_explicit_updates():
636
636
# And again, it should be overridden by an explicit update
637
637
f = compile_pymc ([], x , updates = {rng : x .owner .outputs [0 ]})
638
638
assert f () != f ()
639
+
640
+
641
+ def test_compile_pymc_updates_inputs ():
642
+ """Test that compile_pymc does not include rngs updates of variables that are inputs
643
+ or ancestors to inputs
644
+ """
645
+ x = at .random .normal ()
646
+ y = at .random .normal (x )
647
+ z = at .random .normal (y )
648
+
649
+ for inputs , rvs_in_graph in (
650
+ ([], 3 ),
651
+ ([x ], 2 ),
652
+ ([y ], 1 ),
653
+ ([z ], 0 ),
654
+ ([x , y ], 1 ),
655
+ ([x , y , z ], 0 ),
656
+ ):
657
+ fn = compile_pymc (inputs , z , on_unused_input = "ignore" )
658
+ fn_fgraph = fn .maker .fgraph
659
+ # Each RV adds a shared input for its rng
660
+ assert len (fn_fgraph .inputs ) == len (inputs ) + rvs_in_graph
661
+ # If the output is an input, the graph has a DeepCopyOp
662
+ assert len (fn_fgraph .apply_nodes ) == max (rvs_in_graph , 1 )
663
+ # Each RV adds a shared output for its rng
664
+ assert len (fn_fgraph .outputs ) == 1 + rvs_in_graph
You can’t perform that action at this time.
0 commit comments