
Commit 62d2ab2

Generalize local_subtensor_of_elemwise to Blockwise
1 parent 9fa3885 commit 62d2ab2

File tree

pytensor/tensor/rewriting/subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py

2 files changed: +78 -12 lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 32 additions & 6 deletions
@@ -20,6 +20,7 @@
     join,
     register_infer_shape,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
@@ -169,16 +170,16 @@ def local_subtensor_of_dot(fgraph, node):
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_of_elemwise(fgraph, node):
-    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+def local_subtensor_of_batch_dims(fgraph, node):
+    """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior.
 
     exp(x)[:, 0] -> exp(x[:, 0])
     add(x, y)[0] -> add(x[0], y[0])
     add(x[None], y)[2] -> add(x, y[2])
     """
     elem, *idx = node.inputs
 
-    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)):
         return None
 
     if len(fgraph.clients[elem]) > 1:
@@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node):
 
     idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
+    batch_ndim = (
+        elem.owner.op.batch_ndim(elem.owner)
+        if isinstance(elem.owner.op, Blockwise)
+        else elem.ndim
+    )
+
+    if len(idx_tuple) > batch_ndim:
+        # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
+        batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
+        if all(is_full_slice(idx) for idx in batch_indices):
+            # No batch indices, nothing to do
+            return None
+        elem_with_batch_indices = elem[batch_indices]
+        [elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform(
+            fgraph, elem_with_batch_indices.owner
+        )
+        # Reapply the core_indices
+        core_ndim = elem.type.ndim - batch_ndim
+        # Number of batch dims may have changed with the lifting of indices, so we recompute
+        new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim
+        new_indices = (*(slice(None),) * new_batch_ndim, *core_indices)
+        new_elem = elem_with_batch_indices_lifted[new_indices]
+        copy_stack_trace(node.outputs[0], new_elem)
+        return [new_elem]
+
     elem_inputs = elem.owner.inputs
-    elem_bcast = elem.type.broadcastable
-    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+    elem_bcast = elem.type.broadcastable[:batch_ndim]
+    if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
         # No need to worry about implicit broadcasting.
         indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
 
@@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node):
         zip(
             idx_tuple,
             elem_bcast,
-            *(inp.type.broadcastable for inp in elem_inputs),
+            *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs),
             # Indices can be shorter than input ndims
             strict=False,
         )
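
For context, the effect of the generalized rewrite: indexing the output of an Elemwise or Blockwise on its batch dimensions is pushed down onto the inputs, and when the indices also reach into core dimensions of a Blockwise, only the batch part is lifted while the core indices are reapplied on top. A minimal sketch of the user-visible behavior, assuming batched matmul is dispatched to Blockwise (as in recent PyTensor); the variable names are illustrative only:

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.tensor("x", shape=(7, 5, 3))
y = pt.tensor("y", shape=(7, 3, 2))

# Slice only the leading batch dimension of a Blockwise (batched matmul) output
out = (x @ y)[2:]

# After rewriting, the slice is lifted onto the inputs,
# roughly equivalent to x[2:] @ y[2:]
rewritten = rewrite_graph(out)
```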

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 46 additions & 6 deletions
@@ -14,6 +14,7 @@
 from pytensor.graph import (
     Constant,
     FunctionGraph,
+    Op,
     RewriteDatabaseQuery,
     Type,
     rewrite_graph,
@@ -23,6 +24,7 @@
 from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
+    dvector,
     exp,
     iscalar,
     iscalars,
@@ -37,11 +39,12 @@
     vector,
 )
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
-    local_subtensor_of_elemwise,
+    local_subtensor_of_batch_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
@@ -58,7 +61,7 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
 
 
-class TestLocalSubtensorOfElemwise:
+class TestLocalSubtensorOfBatchDims:
     def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
@@ -144,7 +147,7 @@ def test_multinary_multiple_clients(self):
             ),
         ],
     )
-    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+    def test_elemwise(self, original_fn, expected_fn):
         rng = np.random.default_rng(257)
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
@@ -163,19 +166,56 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
             out.eval({x: x_test, y: y_test}, **eval_kwargs),
         )
 
-    def test_local_subtensor_of_elemwise_multiple_clients(self):
+    def test_elemwise_multiple_clients(self):
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
         out1 = add(x, y)
         out2 = out1[0]
 
         # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
         fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
 
         # Otherwise it should work
         fgraph.remove_output(0)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
+
+    def test_blockwise(self):
+        class CoreTestOp(Op):
+            itypes = [dvector, dvector]
+            otypes = [dvector]
+
+            def perform(self, node, inputs, output_storage):
+                output_storage[0][0] = np.convolve(*inputs, mode="valid")
+
+        core_test_op = CoreTestOp()
+        block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")
+
+        x = tensor3("x", shape=(7, 5, 11), dtype="float64")
+        y = tensor("y", shape=(7, 33), dtype="float64")
+        out = block_test_op(x, y[:, None, :])
+        assert isinstance(out.owner.op, Blockwise)
+
+        out_sliced = out[2:][:, 3:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
+
+        rng = np.random.default_rng(191)
+        x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+        np.testing.assert_allclose(
+            rewritten_out_sliced.eval(
+                {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
+            ),
+            out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        )
+
+        # Check slice on core dims
+        out_sliced = out[2:][:, 0][:, 4:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
 
 
 @pytest.mark.parametrize(
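
For readers unfamiliar with the signature used in the new test, "(a),(b)->(c)" follows NumPy's generalized-ufunc convention: the parenthesized labels name the core dimensions of each input and output, and all leading dimensions are batch dimensions that broadcast against each other. A NumPy-only sketch of the same convention, with shapes mirroring the test (an illustration of the signature, not the Blockwise implementation):

```python
import numpy as np

# np.vectorize with the same signature batches a "valid" convolution over the
# leading (batch) dimensions, analogous to Blockwise(core_test_op, "(a),(b)->(c)")
batched_convolve = np.vectorize(
    lambda a, b: np.convolve(a, b, mode="valid"), signature="(a),(b)->(c)"
)

x = np.random.normal(size=(7, 5, 11))
y = np.random.normal(size=(7, 1, 33))  # batch dims (7, 1) broadcast against (7, 5)

out = batched_convolve(x, y)
print(out.shape)  # (7, 5, 23): broadcast batch dims plus core dim 33 - 11 + 1
```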
