@@ -14,6 +14,7 @@
 from pytensor.graph import (
     Constant,
     FunctionGraph,
+    Op,
     RewriteDatabaseQuery,
     Type,
     rewrite_graph,
@@ -23,6 +24,7 @@
 from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
+    dvector,
     exp,
     iscalar,
     iscalars,
@@ -39,11 +41,12 @@
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
-    local_subtensor_of_elemwise,
+    local_subtensor_of_batch_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
@@ -60,7 +63,7 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
 
 
-class TestLocalSubtensorOfElemwise:
+class TestLocalSubtensorOfBatchDims:
     def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
@@ -146,7 +149,7 @@ def test_multinary_multiple_clients(self):
             ),
         ],
     )
-    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+    def test_elemwise(self, original_fn, expected_fn):
         rng = np.random.default_rng(257)
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
@@ -165,19 +168,56 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
             out.eval({x: x_test, y: y_test}, **eval_kwargs),
         )
 
-    def test_local_subtensor_of_elemwise_multiple_clients(self):
+    def test_elemwise_multiple_clients(self):
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
         out1 = add(x, y)
         out2 = out1[0]
 
         # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
         fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
 
         # Otherwise it should work
         fgraph.remove_output(0)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
+
+    def test_blockwise(self):
+        class CoreTestOp(Op):
+            itypes = [dvector, dvector]
+            otypes = [dvector]
+
+            def perform(self, node, inputs, output_storage):
+                output_storage[0][0] = np.convolve(*inputs, mode="valid")
+
+        core_test_op = CoreTestOp()
+        block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")
+
+        x = tensor3("x", shape=(7, 5, 11), dtype="float64")
+        y = tensor("y", shape=(7, 33), dtype="float64")
+        out = block_test_op(x, y[:, None, :])
+        assert isinstance(out.owner.op, Blockwise)
+
+        out_sliced = out[2:][:, 3:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
+
+        rng = np.random.default_rng(191)
+        x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+        np.testing.assert_allclose(
+            rewritten_out_sliced.eval(
+                {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
+            ),
+            out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        )
+
+        # Check slices on core dims: batch-dim slices are lifted, but the core-dim slice stays on the output
+        out_sliced = out[2:][:, 0][:, 4:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
 
 
 def test_local_subtensor_of_dot():
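For readers skimming the diff, here is a minimal sketch (not part of the change) of the lift that `test_elemwise` exercises. It assumes the hidden parametrized cases follow the pattern suggested by the test setup, e.g. `exp(x + y)[0]`: indexing that touches only batch dimensions commutes with `Elemwise`, so the rewrite is expected to push the subtensor onto the inputs, and only the indexed row is ever computed.

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))

# Subtensor applied to the output of an Elemwise op
out = pt.exp(x + y)[0]
rewritten = rewrite_graph(out)

# The slice should be lifted above the Elemwise onto its inputs
expected = pt.exp(x[0] + y[0])
assert equal_computations([rewritten], [expected])

The renamed `local_subtensor_of_batch_dims` generalizes this to `Blockwise`, as `test_blockwise` checks: slices on batch dimensions are lifted into the op's inputs, while slices on core dimensions must stay on the output.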