11import numpy as np
22import pytest
3- import unittest_tools as utt
43
54from pytensor import (
65 Mode ,
2524from pytensor .tensor import (
2625 add ,
2726 exp ,
28- inplace ,
2927 iscalar ,
3028 iscalars ,
3129 lscalar ,
3230 lscalars ,
3331 matrix ,
34- scalar ,
3532 shape ,
3633 slicetype ,
3734 specify_shape ,
4340from pytensor .tensor .elemwise import DimShuffle , Elemwise
4441from pytensor .tensor .rewriting .subtensor_lift import (
4542 local_subtensor_make_vector ,
43+ local_subtensor_of_elemwise ,
4644 local_subtensor_shape_constant ,
4745)
4846from pytensor .tensor .shape import SpecifyShape , _shape
5856NO_OPTIMIZATION_MODE = Mode (linker = "py" , optimizer = None )
5957
6058
61- class TestLocalSubtensorLift :
62- def test_basic (self ):
63- # basic test that the Op works
64- x = matrix ("x" )
65- f = function ([x ], exp (x )[0 ], mode = mode_opt )
66-
67- # Check stacktrace was copied over correctly after opt was applied
68- assert check_stack_trace (f , ops_to_check = "all" )
69-
70- prog = f .maker .fgraph .toposort ()
71- assert isinstance (prog [0 ].op , Subtensor ) # first subtensor
72- assert prog [1 ].op == exp
73- assert len (prog ) == 2
74- f ([[0 , 1 ], [2 , 3 ]]) # let debugmode test something
75-
76- def test_basic_1 (self ):
59+ class TestLocalSubtensorOfElemwise :
60+ def test_unary_multiple_clients (self ):
7761 # as test0, but we reuse the output of the elemwise
7862 # So we should not lift the subtensor
7963 x = matrix ("x" )
@@ -87,85 +71,16 @@ def test_basic_1(self):
8771 assert isinstance (prog [1 ].op , Subtensor ) # first subtensor
8872 assert isinstance (prog [2 ].op , DeepCopyOp )
8973 assert len (prog ) == 3
90- f ([[0 , 1 ], [2 , 3 ]]) # let debugmode test something
91-
92- def test_basic_2 (self ):
93- # basic test that the optimization work with scalar broadcasted
94- x = matrix ("x" )
95- y = scalar ("y" )
96- z = matrix ("z" )
97- f = function ([x , y , z ], exp (x + y + z )[0 ], mode = mode_opt )
98-
99- prog = f .maker .fgraph .toposort ()
100- assert isinstance (prog [0 ].op , Subtensor )
101- assert isinstance (prog [1 ].op , DimShuffle )
102- assert isinstance (prog [2 ].op , Subtensor )
103- assert isinstance (prog [3 ].op .scalar_op , ps .Composite ) # Composite{add,add}
104- assert len (prog ) == 4
105-
106- # Check stacktrace was copied over correctly after opt was applied
107- assert check_stack_trace (f , ops_to_check = [Subtensor ])
108-
109- # let debugmode test something
110- f ([[0 , 1 ], [2 , 3 ]], 4 , [[4 , 5 ], [6 , 7 ]])
111-
112- def test_basic_3 (self ):
113- # as 1, but take a slice
114- x = matrix ("x" )
115- y = scalar ("y" )
116- z = matrix ("z" )
117- f = function ([x , y , z ], exp (x + y + z )[0 :2 ], mode = mode_opt )
118-
119- prog = f .maker .fgraph .toposort ()
120- assert isinstance (prog [0 ].op , Subtensor )
121- assert isinstance (prog [1 ].op , DimShuffle )
122- assert isinstance (prog [2 ].op , Subtensor )
123- assert isinstance (prog [3 ].op .scalar_op , ps .Composite ) # Composite{add,add}
124- assert len (prog ) == 4
125-
126- # Check stacktrace was copied over correctly after opt was applied
127- assert check_stack_trace (f , ops_to_check = [Subtensor ])
128-
129- # let debugmode test something
130- f ([[0 , 1 ], [2 , 3 ]], 4 , [[4 , 5 ], [6 , 7 ]])
131-
132- def test_basic_4 (self ):
133- # basic test that the optimization does work with broadcasting
134- # for unary elemwise.
135- y = vector ("y" )
136- f = function ([y ], exp (y .dimshuffle (0 , "x" ))[0 ], mode = mode_opt )
137-
138- # Check stacktrace was copied over correctly after opt was applied
139- assert check_stack_trace (f , ops_to_check = "all" )
140-
141- prog = f .maker .fgraph .toposort ()
142- assert isinstance (prog [0 ].op , Subtensor )
143- assert isinstance (prog [1 ].op , DimShuffle )
144- assert prog [2 ].op == exp
145- assert len (prog ) == 3
146- f ([4 , 5 ]) # let debugmode test something
147-
148- @utt .assertFailure_fast
149- def test_basic_5 (self ):
150- # basic test that the optimization doesn't work with broadcasting
151- # ... It *could* be extended to,
152- # ... but right now it doesn't, so it shouldn't try.
153- x = matrix ("x" )
154- y = vector ("y" )
155- f = function ([x , y ], exp (x + y )[0 ], mode = mode_opt )
15674
157- # Opt doesn't apply, so no need for check_stack_trace
158- # assert check_stack_trace(f, ops_to_check='all')
159-
160- prog = f .maker .fgraph .toposort ()
161- assert isinstance (prog [0 ].op , DimShuffle )
162- assert prog [1 ].op == add
163- assert isinstance (prog [2 ].op , Subtensor ) # first subtensor
164- assert prog [3 ].op == inplace .exp_inplace
165- assert len (prog ) == 4
166- f ([[0 , 1 ], [2 , 3 ]], [4 , 5 ]) # let debugmode test something
75+ x_test = [[0 , 1 ], [2 , 3 ]]
76+ res1 , res2 = f (x_test )
77+ np .testing .assert_allclose (
78+ res1 ,
79+ np .exp (x_test )[0 ],
80+ )
81+ np .testing .assert_allclose (res2 , np .exp (x_test ))
16782
168- def test_basic_6 (self ):
83+ def test_multinary_multiple_clients (self ):
16984 # test that we don't lift when we reuse the output of the
17085 # elemwise for other computation.
17186 x = matrix ("x" )
@@ -181,26 +96,84 @@ def test_basic_6(self):
18196 # first subtensor
18297 assert isinstance (prog [2 ].op , Subtensor )
18398 assert len (prog ) == 3
184- f ([[0 , 1 ], [2 , 3 ]], [4 , 5 ]) # let debugmode test something
18599
186- def test_basic_7 (self ):
187- # basic test that the optimization works with a scalar as input,
188- # and a scalar as output (no broadcasting of the scalar needed).
189- # The optimization used to fail and display an ERROR message.
100+ x_test = np .array ([[0 , 1 ], [2 , 3 ]]).astype (x .dtype )
101+ y_test = np .array ([4 , 5 ]).astype (y .dtype )
102+ res1 , res2 = f (x_test , y_test )
103+ np .testing .assert_allclose (
104+ res1 ,
105+ np .exp (x_test + y_test )[0 ],
106+ )
107+ np .testing .assert_allclose (
108+ res2 ,
109+ np .exp (x_test + y_test ) + x_test ,
110+ )
111+
112+ @pytest .mark .parametrize (
113+ "original_fn, expected_fn" ,
114+ [
115+ # Unary integer indexing
116+ (lambda x , y : exp (x )[0 ], lambda x , y : exp (x [0 ])),
117+ # Unary integer with expand_dims
118+ (lambda x , y : exp (x [:, None ])[0 ], lambda x , y : exp (x [0 ][None ])),
119+ # Integer indexing on non-broadcastable dimension
120+ (lambda x , y : add (x , y )[0 ], lambda x , y : add (x [0 ], y [0 ])),
121+ # Slice indexing on non-broadcastable dimension
122+ (lambda x , y : add (x , y )[1 :], lambda x , y : add (x [1 :], y [1 :])),
123+ # Integer indexing on broacastable dimension
124+ (lambda x , y : add (x [None ], y [None ])[0 ], lambda x , y : add (x , y )),
125+ (lambda x , y : add (x [None ], y [None ])[0 , 1 ], lambda x , y : add (x [1 ], y [1 ])),
126+ (
127+ lambda x , y : add (x [None , :], y [:, None ])[2 ],
128+ lambda x , y : add (x , y [2 ][None ]),
129+ ),
130+ (
131+ lambda x , y : add (x [:, None ], y [None , :])[:, 2 ],
132+ lambda x , y : add (x , y [2 ][None ]),
133+ ),
134+ # Slice indexing on broadcastable dimension
135+ (
136+ lambda x , y : add (x [None ], y [None ])[1 :],
137+ lambda x , y : add (x [None ][1 :], y [None ][1 :]),
138+ ),
139+ (
140+ lambda x , y : add (x [None , :], y [:, None ])[1 :],
141+ lambda x , y : add (x [None , :], y [1 :][:, None ]),
142+ ),
143+ ],
144+ )
145+ def test_local_subtensor_of_elemwise (self , original_fn , expected_fn ):
146+ rng = np .random .default_rng (257 )
147+ x = pt .matrix ("x" , shape = (5 , 3 ))
148+ y = pt .matrix ("y" , shape = (5 , 3 ))
149+ x_test = rng .normal (size = x .type .shape ).astype (x .dtype )
150+ y_test = rng .normal (size = y .type .shape ).astype (y .dtype )
151+
152+ out = original_fn (x , y )
153+ expected_opt_out = expected_fn (x , y )
154+ opt_out = rewrite_graph (out )
155+ assert equal_computations ([opt_out ], [expected_opt_out ]), debugprint (
156+ [expected_opt_out , opt_out ], print_type = True
157+ )
158+ eval_kwargs = dict (mode = NO_OPTIMIZATION_MODE , on_unused_input = "ignore" )
159+ np .testing .assert_allclose (
160+ opt_out .eval ({x : x_test , y : y_test }, ** eval_kwargs ),
161+ out .eval ({x : x_test , y : y_test }, ** eval_kwargs ),
162+ )
190163
191- x = vector ("x" )
192- y = scalar ("y" )
193- f = function ([x , y ], exp (x + y )[0 ], mode = mode_opt )
164+ def test_local_subtensor_of_elemwise_multiple_clients (self ):
165+ x = pt .matrix ("x" , shape = (5 , 3 ))
166+ y = pt .matrix ("y" , shape = (5 , 3 ))
167+ out1 = add (x , y )
168+ out2 = out1 [0 ]
194169
195- # Check stacktrace was copied over correctly after opt was applied
196- assert check_stack_trace (f , ops_to_check = Subtensor )
170+ # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
171+ fgraph = FunctionGraph ([x , y ], [out1 , out2 ], clone = False )
172+ assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is None
197173
198- prog = f .maker .fgraph .toposort ()
199- assert isinstance (prog [0 ].op , Subtensor )
200- # Composite{add,exp}
201- assert isinstance (prog [1 ].op .scalar_op , ps .Composite )
202- assert len (prog ) == 2
203- f ([1 , 2 , 3 ], 4 ) # let debugmode test something
174+ # Otherwise it should work
175+ fgraph .remove_output (0 )
176+ assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is not None
204177
205178
206179@pytest .mark .parametrize (
0 commit comments