
Commit 220fef2

Generalize local_subtensor_of_elemwise to Blockwise
1 parent 49f83bc commit 220fef2
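
The rewrite previously lifted a Subtensor only through Elemwise nodes; this commit extends it to the batch dimensions of Blockwise nodes as well. As a rough illustration of the effect (a minimal sketch, not part of the commit, and assuming pt.matmul is implemented via Blockwise as in current PyTensor), slicing the leading batch dimension of a batched matmul can now be pushed into its inputs:

    import pytensor.tensor as pt
    from pytensor.graph import rewrite_graph

    # One batch dim (size 7) followed by the two core matrix dims.
    x = pt.tensor3("x", shape=(7, 5, 3))
    y = pt.tensor3("y", shape=(7, 3, 2))
    out = pt.matmul(x, y)  # assumed to be a Blockwise node, signature (m,k),(k,n)->(m,n)

    # After this commit the slice on the batch dim is lifted into the inputs, roughly:
    # matmul(x, y)[2:]  ->  matmul(x[2:], y[2:])
    lifted = rewrite_graph(out[2:])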

2 files changed: 78 additions, 12 deletions

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 32 additions & 6 deletions
@@ -20,6 +20,7 @@
     join,
     register_infer_shape,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
@@ -169,16 +170,16 @@ def local_subtensor_of_dot(fgraph, node):
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_of_elemwise(fgraph, node):
-    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+def local_subtensor_of_batch_dims(fgraph, node):
+    """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior.
 
     exp(x)[:, 0] -> exp(x[:, 0])
     add(x, y)[0] -> add(x[0], y[0])
     add(x[None], y)[2] -> add(x, y[2])
     """
     elem, *idx = node.inputs
 
-    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)):
         return None
 
     if len(fgraph.clients[elem]) > 1:
@@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node):
 
     idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
+    batch_ndim = (
+        elem.owner.op.batch_ndim(elem.owner)
+        if isinstance(elem.owner.op, Blockwise)
+        else elem.ndim
+    )
+
+    if len(idx_tuple) > batch_ndim:
+        # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
+        batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
+        if all(is_full_slice(idx) for idx in batch_indices):
+            # No batch indices, nothing to do
+            return None
+        elem_with_batch_indices = elem[batch_indices]
+        [elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform(
+            fgraph, elem_with_batch_indices.owner
+        )
+        # Reapply the core_indices
+        core_ndim = elem.type.ndim - batch_ndim
+        # Number of batch dims may have changed with the lifting of indices, so we recompute
+        new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim
+        new_indices = (*(slice(None),) * new_batch_ndim, *core_indices)
+        new_elem = elem_with_batch_indices_lifted[new_indices]
+        copy_stack_trace(node.outputs[0], new_elem)
+        return [new_elem]
+
     elem_inputs = elem.owner.inputs
-    elem_bcast = elem.type.broadcastable
-    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+    elem_bcast = elem.type.broadcastable[:batch_ndim]
+    if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
         # No need to worry about implicit broadcasting.
         indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
 
@@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node):
         zip(
             idx_tuple,
             elem_bcast,
-            *(inp.type.broadcastable for inp in elem_inputs),
+            *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs),
             # Indices can be shorter than input ndims
            strict=False,
        )
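
To make the new core-dimension handling concrete, here is a small plain-Python walk-through (illustrative values only, not taken from the commit) of how the rewrite splits the index tuple when some indices reach past the batch dimensions, assuming upstream rewrites have already merged consecutive indexing into a single Subtensor:

    # Suppose the Blockwise output has batch_ndim = 1 and the user wrote out[2:][:, 0],
    # merged into one index tuple (hypothetical values):
    idx_tuple = (slice(2, None), 0)
    batch_ndim = 1

    # Only the leading batch indices are lifted into the inputs; the rest are held back.
    batch_indices = idx_tuple[:batch_ndim]   # (slice(2, None),)
    core_indices = idx_tuple[batch_ndim:]    # (0,)

    # After lifting, the number of batch dims may have changed, so the core indices
    # are reapplied behind full slices over the recomputed batch dims:
    new_batch_ndim = 1
    new_indices = (*(slice(None),) * new_batch_ndim, *core_indices)
    assert new_indices == (slice(None), 0)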

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 46 additions & 6 deletions
@@ -14,6 +14,7 @@
 from pytensor.graph import (
     Constant,
     FunctionGraph,
+    Op,
     RewriteDatabaseQuery,
     Type,
     rewrite_graph,
@@ -23,6 +24,7 @@
 from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
+    dvector,
     exp,
     iscalar,
     iscalars,
@@ -39,11 +41,12 @@
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
-    local_subtensor_of_elemwise,
+    local_subtensor_of_batch_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
@@ -60,7 +63,7 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
 
 
-class TestLocalSubtensorOfElemwise:
+class TestLocalSubtensorOfBatchDims:
     def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
@@ -146,7 +149,7 @@ def test_multinary_multiple_clients(self):
             ),
         ],
     )
-    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+    def test_elemwise(self, original_fn, expected_fn):
        rng = np.random.default_rng(257)
        x = pt.matrix("x", shape=(5, 3))
        y = pt.matrix("y", shape=(5, 3))
@@ -165,19 +168,56 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
             out.eval({x: x_test, y: y_test}, **eval_kwargs),
         )
 
-    def test_local_subtensor_of_elemwise_multiple_clients(self):
+    def test_elemwise_multiple_clients(self):
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
         out1 = add(x, y)
         out2 = out1[0]
 
         # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
         fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
 
         # Otherwise it should work
         fgraph.remove_output(0)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
+
+    def test_blockwise(self):
+        class CoreTestOp(Op):
+            itypes = [dvector, dvector]
+            otypes = [dvector]
+
+            def perform(self, node, inputs, output_storage):
+                output_storage[0][0] = np.convolve(*inputs, mode="valid")
+
+        core_test_op = CoreTestOp()
+        block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")
+
+        x = tensor3("x", shape=(7, 5, 11), dtype="float64")
+        y = tensor("y", shape=(7, 33), dtype="float64")
+        out = block_test_op(x, y[:, None, :])
+        assert isinstance(out.owner.op, Blockwise)
+
+        out_sliced = out[2:][:, 3:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
+
+        rng = np.random.default_rng(191)
+        x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+        np.testing.assert_allclose(
+            rewritten_out_sliced.eval(
+                {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
+            ),
+            out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        )
+
+        # Check slice on core dims
+        out_sliced = out[2:][:, 0][:, 4:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
 
 
 def test_local_subtensor_of_dot():
