@@ -14,6 +14,7 @@
 from pytensor.graph import (
     Constant,
     FunctionGraph,
+    Op,
     RewriteDatabaseQuery,
     Type,
     rewrite_graph,
@@ -23,6 +24,7 @@
 from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
+    dvector,
     exp,
     iscalar,
     iscalars,
@@ -37,11 +39,12 @@
     vector,
 )
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
-    local_subtensor_of_elemwise,
+    local_subtensor_of_batch_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
@@ -58,7 +61,7 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
 
 
-class TestLocalSubtensorOfElemwise:
+class TestLocalSubtensorOfBatchDims:
     def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
@@ -144,7 +147,7 @@ def test_multinary_multiple_clients(self):
             ),
         ],
     )
-    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+    def test_elemwise(self, original_fn, expected_fn):
         rng = np.random.default_rng(257)
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
@@ -163,19 +166,56 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
             out.eval({x: x_test, y: y_test}, **eval_kwargs),
         )
 
-    def test_local_subtensor_of_elemwise_multiple_clients(self):
+    def test_elemwise_multiple_clients(self):
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
         out1 = add(x, y)
         out2 = out1[0]
 
         # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
         fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
 
         # Otherwise it should work
         fgraph.remove_output(0)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
+
+    def test_blockwise(self):
+        class CoreTestOp(Op):
+            itypes = [dvector, dvector]
+            otypes = [dvector]
+
+            def perform(self, node, inputs, output_storage):
+                output_storage[0][0] = np.convolve(*inputs, mode="valid")
+
+        core_test_op = CoreTestOp()
+        block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")
+
+        x = tensor3("x", shape=(7, 5, 11), dtype="float64")
+        y = tensor("y", shape=(7, 33), dtype="float64")
+        out = block_test_op(x, y[:, None, :])
+        assert isinstance(out.owner.op, Blockwise)
+
+        out_sliced = out[2:][:, 3:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
+
+        rng = np.random.default_rng(191)
+        x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+        np.testing.assert_allclose(
+            rewritten_out_sliced.eval(
+                {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
+            ),
+            out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        )
+
+        # Check slice on core dims
+        out_sliced = out[2:][:, 0][:, 4:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
 
 
 @pytest.mark.parametrize(
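
For context on the rename: `local_subtensor_of_batch_dims` generalizes the old `local_subtensor_of_elemwise` by lifting `Subtensor` indexing that only touches batch dimensions through both `Elemwise` and `Blockwise` nodes, so the slice is applied to the inputs rather than the larger output. A minimal sketch of the behavior the tests above assert, assuming the lift runs as part of the default `rewrite_graph` pipeline:

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))

# A subtensor taken from the *output* of the elemwise addition...
out = (x + y)[0]

# ...should be lifted onto the inputs, so the add runs on smaller arrays.
rewritten = rewrite_graph(out)
assert equal_computations([rewritten], [x[0] + y[0]])
```

`test_blockwise` exercises the same idea for `Blockwise`: slices over the leading batch dimensions are pushed into the inputs, while indexing into the core dimension (`(c)` in the signature) stays outside the `Op`, as the final `expected_out_sliced` shows.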