@@ -835,6 +835,39 @@ def test_OpFromGraph():
835835 compare_numba_and_py ([x , y , z ], [out ], [xv , yv , zv ])
836836
837837
838+ @pytest .mark .filterwarnings ("error" )
839+ def test_ofg_inner_inplace ():
840+ x = pt .vector ("x" )
841+ set0 = x [0 ].set (1 ) # SetSubtensor should not inplace on x
842+ exp_x = pt .exp (x )
843+ set1 = exp_x [0 ].set (1 ) # SetSubtensor should inplace on exp_x
844+ ofg0 = OpFromGraph ([x ], [set0 ])
845+ ofg1 = OpFromGraph ([x ], [set1 ])
846+
847+ y , z = pt .vectors ("y" , "z" )
848+ fn = function ([y , z ], [ofg0 (y ), ofg1 (z )], mode = "NUMBA" )
849+
850+ fn_ofg0 = fn .maker .fgraph .outputs [0 ].owner .op
851+ assert isinstance (fn_ofg0 , OpFromGraph )
852+ fn_set0 = fn_ofg0 .fgraph .outputs [0 ]
853+ assert fn_set0 .owner .op .destroy_map == {}
854+
855+ fn_ofg1 = fn .maker .fgraph .outputs [1 ].owner .op
856+ assert isinstance (fn_ofg1 , OpFromGraph )
857+ fn_set1 = fn_ofg1 .fgraph .outputs [0 ]
858+ assert fn_set1 .owner .op .destroy_map == {0 : [0 ]}
859+
860+ x_test = np .array ([0 , 1 , 1 ], dtype = config .floatX )
861+ y_test = np .array ([0 , 1 , 1 ], dtype = config .floatX )
862+ res0 , res1 = fn (x_test , y_test )
863+ # Check inputs were not mutated
864+ np .testing .assert_allclose (x_test , [0 , 1 , 1 ])
865+ np .testing .assert_allclose (y_test , [0 , 1 , 1 ])
866+ # Check outputs are correct
867+ np .testing .assert_allclose (res0 , [1 , 1 , 1 ])
868+ np .testing .assert_allclose (res1 , [1 , np .e , np .e ])
869+
870+
838871@pytest .mark .filterwarnings ("error" )
839872def test_cache_warning_suppressed ():
840873 x = pt .vector ("x" , shape = (5 ,), dtype = "float64" )
0 commit comments