1414from pytensor .graph import (
1515 Constant ,
1616 FunctionGraph ,
17+ Op ,
1718 RewriteDatabaseQuery ,
1819 Type ,
1920 rewrite_graph ,
2324from pytensor .printing import debugprint
2425from pytensor .tensor import (
2526 add ,
27+ dvector ,
2628 exp ,
2729 iscalar ,
2830 iscalars ,
3739 vector ,
3840)
3941from pytensor .tensor .basic import MakeVector , concatenate , expand_dims , make_vector
42+ from pytensor .tensor .blockwise import Blockwise
4043from pytensor .tensor .elemwise import DimShuffle , Elemwise
4144from pytensor .tensor .math import sum as pt_sum
4245from pytensor .tensor .rewriting .subtensor_lift import (
4346 local_subtensor_make_vector ,
44- local_subtensor_of_elemwise ,
47+ local_subtensor_of_batch_dims ,
4548 local_subtensor_shape_constant ,
4649)
4750from pytensor .tensor .shape import SpecifyShape , _shape
5861NO_OPTIMIZATION_MODE = Mode (linker = "py" , optimizer = None )
5962
6063
61- class TestLocalSubtensorOfElemwise :
64+ class TestLocalSubtensorOfBatchDims :
6265 def test_unary_multiple_clients (self ):
6366 # as test0, but we reuse the output of the elemwise
6467 # So we should not lift the subtensor
@@ -144,7 +147,7 @@ def test_multinary_multiple_clients(self):
144147 ),
145148 ],
146149 )
147- def test_local_subtensor_of_elemwise (self , original_fn , expected_fn ):
150+ def test_elemwise (self , original_fn , expected_fn ):
148151 rng = np .random .default_rng (257 )
149152 x = pt .matrix ("x" , shape = (5 , 3 ))
150153 y = pt .matrix ("y" , shape = (5 , 3 ))
@@ -163,19 +166,56 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
163166 out .eval ({x : x_test , y : y_test }, ** eval_kwargs ),
164167 )
165168
166- def test_local_subtensor_of_elemwise_multiple_clients (self ):
169+ def test_elemwise_multiple_clients (self ):
167170 x = pt .matrix ("x" , shape = (5 , 3 ))
168171 y = pt .matrix ("y" , shape = (5 , 3 ))
169172 out1 = add (x , y )
170173 out2 = out1 [0 ]
171174
172175 # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
173176 fgraph = FunctionGraph ([x , y ], [out1 , out2 ], clone = False )
174- assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is None
177+ assert local_subtensor_of_batch_dims .transform (fgraph , out2 .owner ) is None
175178
176179 # Otherwise it should work
177180 fgraph .remove_output (0 )
178- assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is not None
181+ assert local_subtensor_of_batch_dims .transform (fgraph , out2 .owner ) is not None
182+
183+ def test_blockwise (self ):
184+ class CoreTestOp (Op ):
185+ itypes = [dvector , dvector ]
186+ otypes = [dvector ]
187+
188+ def perform (self , node , inputs , output_storage ):
189+ output_storage [0 ][0 ] = np .convolve (* inputs , mode = "valid" )
190+
191+ core_test_op = CoreTestOp ()
192+ block_test_op = Blockwise (core_test_op , signature = "(a),(b)->(c)" )
193+
194+ x = tensor3 ("x" , shape = (7 , 5 , 11 ), dtype = "float64" )
195+ y = tensor ("y" , shape = (7 , 33 ), dtype = "float64" )
196+ out = block_test_op (x , y [:, None , :])
197+ assert isinstance (out .owner .op , Blockwise )
198+
199+ out_sliced = out [2 :][:, 3 :]
200+ rewritten_out_sliced = rewrite_graph (out_sliced )
201+ expected_out_sliced = block_test_op (x [2 :, 3 :], y [2 :][:, None , :])
202+ assert equal_computations ([rewritten_out_sliced ], [expected_out_sliced ])
203+
204+ rng = np .random .default_rng (191 )
205+ x_test = rng .normal (size = x .type .shape ).astype (x .type .dtype )
206+ y_test = rng .normal (size = y .type .shape ).astype (y .type .dtype )
207+ np .testing .assert_allclose (
208+ rewritten_out_sliced .eval (
209+ {x : x_test , y : y_test }, mode = NO_OPTIMIZATION_MODE
210+ ),
211+ out_sliced .eval ({x : x_test , y : y_test }, mode = NO_OPTIMIZATION_MODE ),
212+ )
213+
214+ # Check slice on core dims
215+ out_sliced = out [2 :][:, 0 ][:, 4 :]
216+ rewritten_out_sliced = rewrite_graph (out_sliced )
217+ expected_out_sliced = block_test_op (x [2 :, 0 ], y [2 :])[:, 4 :]
218+ assert equal_computations ([rewritten_out_sliced ], [expected_out_sliced ])
179219
180220
181221@pytest .mark .parametrize (
0 commit comments