import numpy as np
import pytest
-import unittest_tools as utt

from pytensor import (
    Mode,
from pytensor.tensor import (
    add,
    exp,
-    inplace,
    iscalar,
    iscalars,
    lscalar,
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.subtensor_lift import (
    local_subtensor_make_vector,
+    local_subtensor_of_elemwise,
    local_subtensor_shape_constant,
)
from pytensor.tensor.shape import SpecifyShape, _shape
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)


-class TestLocalSubtensorLift:
-    def test_basic(self):
-        # basic test that the Op works
-        x = matrix("x")
-        f = function([x], exp(x)[0], mode=mode_opt)
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check="all")
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)  # first subtensor
-        assert prog[1].op == exp
-        assert len(prog) == 2
-        f([[0, 1], [2, 3]])  # let debugmode test something
-
-    def test_basic_1(self):
+class TestLocalSubtensorOfElemwise:
+    def test_unary_multiple_clients(self):
        # as test0, but we reuse the output of the elemwise
        # So we should not lift the subtensor
        x = matrix("x")
@@ -87,85 +72,16 @@ def test_basic_1(self):
        assert isinstance(prog[1].op, Subtensor)  # first subtensor
        assert isinstance(prog[2].op, DeepCopyOp)
        assert len(prog) == 3
-        f([[0, 1], [2, 3]])  # let debugmode test something
-
-    def test_basic_2(self):
-        # basic test that the optimization work with scalar broadcasted
-        x = matrix("x")
-        y = scalar("y")
-        z = matrix("z")
-        f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt)
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        assert isinstance(prog[1].op, DimShuffle)
-        assert isinstance(prog[2].op, Subtensor)
-        assert isinstance(prog[3].op.scalar_op, ps.Composite)  # Composite{add,add}
-        assert len(prog) == 4
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check=[Subtensor])
-
-        # let debugmode test something
-        f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
-
-    def test_basic_3(self):
-        # as 1, but take a slice
-        x = matrix("x")
-        y = scalar("y")
-        z = matrix("z")
-        f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt)
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        assert isinstance(prog[1].op, DimShuffle)
-        assert isinstance(prog[2].op, Subtensor)
-        assert isinstance(prog[3].op.scalar_op, ps.Composite)  # Composite{add,add}
-        assert len(prog) == 4
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check=[Subtensor])
-
-        # let debugmode test something
-        f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
-
-    def test_basic_4(self):
-        # basic test that the optimization does work with broadcasting
-        # for unary elemwise.
-        y = vector("y")
-        f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt)
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check="all")
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        assert isinstance(prog[1].op, DimShuffle)
-        assert prog[2].op == exp
-        assert len(prog) == 3
-        f([4, 5])  # let debugmode test something
-
-    @utt.assertFailure_fast
-    def test_basic_5(self):
-        # basic test that the optimization doesn't work with broadcasting
-        # ... It *could* be extended to,
-        # ... but right now it doesn't, so it shouldn't try.
-        x = matrix("x")
-        y = vector("y")
-        f = function([x, y], exp(x + y)[0], mode=mode_opt)

-        # Opt doesn't apply, so no need for check_stack_trace
-        # assert check_stack_trace(f, ops_to_check='all')
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, DimShuffle)
-        assert prog[1].op == add
-        assert isinstance(prog[2].op, Subtensor)  # first subtensor
-        assert prog[3].op == inplace.exp_inplace
-        assert len(prog) == 4
-        f([[0, 1], [2, 3]], [4, 5])  # let debugmode test something
+        x_test = [[0, 1], [2, 3]]
+        res1, res2 = f(x_test)
+        np.testing.assert_allclose(
+            res1,
+            np.exp(x_test)[0],
+        )
+        np.testing.assert_allclose(res2, np.exp(x_test))
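+        # Both clients of exp(x) must still match the un-rewritten graph,
+        # since the lift was (correctly) not applied here.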

-    def test_basic_6(self):
+    def test_multinary_multiple_clients(self):
        # test that we don't lift when we reuse the output of the
        # elemwise for other computation.
        x = matrix("x")
@@ -181,26 +97,84 @@ def test_basic_6(self):
        # first subtensor
        assert isinstance(prog[2].op, Subtensor)
        assert len(prog) == 3
-        f([[0, 1], [2, 3]], [4, 5])  # let debugmode test something

-    def test_basic_7(self):
-        # basic test that the optimization works with a scalar as input,
-        # and a scalar as output (no broadcasting of the scalar needed).
-        # The optimization used to fail and display an ERROR message.
+        x_test = np.array([[0, 1], [2, 3]]).astype(x.dtype)
+        y_test = np.array([4, 5]).astype(y.dtype)
+        res1, res2 = f(x_test, y_test)
+        np.testing.assert_allclose(
+            res1,
+            np.exp(x_test + y_test)[0],
+        )
+        np.testing.assert_allclose(
+            res2,
+            np.exp(x_test + y_test) + x_test,
+        )
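+        # res2 reuses the full elemwise output, which is exactly what keeps
+        # the subtensor in res1 from being lifted.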
+
+    @pytest.mark.parametrize(
+        "original_fn, expected_fn",
+        [
+            # Unary integer indexing
+            (lambda x, y: exp(x)[0], lambda x, y: exp(x[0])),
+            # Unary integer with expand_dims
+            (lambda x, y: exp(x[:, None])[0], lambda x, y: exp(x[0][None])),
+            # Integer indexing on non-broadcastable dimension
+            (lambda x, y: add(x, y)[0], lambda x, y: add(x[0], y[0])),
+            # Slice indexing on non-broadcastable dimension
+            (lambda x, y: add(x, y)[1:], lambda x, y: add(x[1:], y[1:])),
+            # Integer indexing on broadcastable dimension
+            (lambda x, y: add(x[None], y[None])[0], lambda x, y: add(x, y)),
+            (lambda x, y: add(x[None], y[None])[0, 1], lambda x, y: add(x[1], y[1])),
+            (
+                lambda x, y: add(x[None, :], y[:, None])[2],
+                lambda x, y: add(x, y[2][None]),
+            ),
+            (
+                lambda x, y: add(x[:, None], y[None, :])[:, 2],
+                lambda x, y: add(x, y[2][None]),
+            ),
+            # Slice indexing on broadcastable dimension
+            (
+                lambda x, y: add(x[None], y[None])[1:],
+                lambda x, y: add(x[None][1:], y[None][1:]),
+            ),
+            (
+                lambda x, y: add(x[None, :], y[:, None])[1:],
+                lambda x, y: add(x[None, :], y[1:][:, None]),
+            ),
+        ],
+    )
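+    # Each case pairs an indexed Elemwise graph with the graph the lift is
+    # expected to produce: roughly, Elemwise(x, y)[idx] becomes
+    # Elemwise(x[idx], y[idx]), with the index adjusted (or dropped) on
+    # broadcastable dimensions.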
+    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+        rng = np.random.default_rng(257)
+        x = pt.matrix("x", shape=(5, 3))
+        y = pt.matrix("y", shape=(5, 3))
+        x_test = rng.normal(size=x.type.shape).astype(x.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.dtype)
+
+        out = original_fn(x, y)
+        expected_opt_out = expected_fn(x, y)
+        opt_out = rewrite_graph(out)
+        assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+            [expected_opt_out, opt_out], print_type=True
+        )
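+        # Evaluate with rewrites disabled so the numerical check compares the
+        # two graphs as written, not further-optimized versions of them.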
+        eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore")
+        np.testing.assert_allclose(
+            opt_out.eval({x: x_test, y: y_test}, **eval_kwargs),
+            out.eval({x: x_test, y: y_test}, **eval_kwargs),
+        )

-        x = vector("x")
-        y = scalar("y")
-        f = function([x, y], exp(x + y)[0], mode=mode_opt)
+    def test_local_subtensor_of_elemwise_multiple_clients(self):
+        x = pt.matrix("x", shape=(5, 3))
+        y = pt.matrix("y", shape=(5, 3))
+        out1 = add(x, y)
+        out2 = out1[0]

-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check=Subtensor)
+        # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
+        fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
+        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None

-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        # Composite{add,exp}
-        assert isinstance(prog[1].op.scalar_op, ps.Composite)
-        assert len(prog) == 2
-        f([1, 2, 3], 4)  # let debugmode test something
+        # Otherwise it should work
+        fgraph.remove_output(0)
+        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
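+        # remove_output(0) drops out1 as an explicit output, leaving the
+        # Subtensor in out2 as its only client, so the lift can now apply.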


@pytest.mark.parametrize(