Commit 0f3edf9 (parent: bc9baac)

Lift Subtensor over CAReduce
2 files changed: 94 additions, 1 deletion

pytensor/tensor/rewriting/subtensor_lift.py
Lines changed: 59 additions & 1 deletion

@@ -1,6 +1,7 @@
 from collections.abc import Iterable
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
 
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
@@ -15,7 +16,7 @@
     get_underlying_scalar_constant_value,
     register_infer_shape,
 )
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import Dot, ceil_intdiv, dot
@@ -183,6 +184,63 @@ def local_subtensor_of_elemwise(fgraph, node):
         return [new_out]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_reduce(fgraph, node):
+    """Lift a Subtensor through a CAReduce Op.
+
+    For now the rewrite is restricted to a single axis of reduction, for simplicity.
+
+    sum(x, axis=1)[0] -> sum(x[0], axis=0)
+    sum(x, axis=1)[1:] -> sum(x[1:], axis=1)
+    sum(x, axis=0)[0] -> sum(x[:, 0], axis=0)
+    sum(x, axis=0)[1:] -> sum(x[:, 1:], axis=0)
+
+    """
+    red, *idx = node.inputs
+
+    if not (red.owner and isinstance(red.owner.op, CAReduce)):
+        return None
+
+    if len(fgraph.clients[red]) > 1:
+        # Don't apply the rewrite if another node requires the full reduction
+        return None
+
+    [x] = red.owner.inputs
+    axis = red.owner.op.axis
+
+    if axis is None:
+        axis = tuple(range(x.type.ndim))
+
+    # TODO: Allow reduction across multiple axes
+    if len(axis) != 1:
+        return None
+
+    [axis] = normalize_axis_tuple(axis, x.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    # Index the input of the reduction.
+    new_idxs = list(idx_tuple)
+    if axis < len(idx_tuple):
+        # When there are indices at or beyond the axis of reduction,
+        # shift them by inserting a full slice at that axis.
+        new_idxs.insert(axis, slice(None))
+    x_sub = x[tuple(new_idxs)]
+
+    [old_out] = node.outputs
+    copy_stack_trace(old_out, x_sub)
+
+    # Adjust the axis of reduction when indexing drops dimensions
+    # (integer indexing, as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    # Apply the reduction to the indexed input
+    out = type(red.owner.op)(axis=axis)(x_sub)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
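
To see what the new rewrite does end to end, here is a minimal sketch (not part of the commit; it assumes `rewrite_graph` is importable from `pytensor.graph.rewriting.utils`, as the test module below suggests, and that it runs the canonicalize/specialize passes this rewrite is registered under):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint

x = pt.tensor("x", shape=(5, 3, 2))

# A Subtensor applied to a CAReduce (Sum): index the reduced result
out = pt.sum(x, axis=1)[0]

# After rewriting, the graph should instead compute pt.sum(x[0], axis=0):
# the index is pushed below the reduction, onto the reduction's input.
opt = rewrite_graph(out)
debugprint(opt)

With the Subtensor lifted, the reduction runs on a (3, 2) slice instead of the full (5, 3, 2) input.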

tests/tensor/rewriting/test_subtensor_lift.py
Lines changed: 35 additions & 0 deletions

@@ -44,6 +44,7 @@
 )
 from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
     local_subtensor_of_elemwise,
@@ -178,6 +179,40 @@ def test_local_subtensor_of_elemwise_multiple_clients():
     assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
 
 
+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Indexing before the axis of reduction
+        (lambda x: pt_sum(x, axis=2)[0], lambda x: pt_sum(x[0], axis=1)),
+        (lambda x: pt_sum(x, axis=2)[0, 1], lambda x: pt_sum(x[0, 1], axis=None)),
+        (lambda x: pt_sum(x, axis=2)[1:], lambda x: pt_sum(x[1:], axis=2)),
+        # Indexing "at" the axis of reduction
+        (lambda x: pt_sum(x, axis=0)[2], lambda x: pt_sum(x[:, 2], axis=0)),
+        (lambda x: pt_sum(x, axis=0)[:-2], lambda x: pt_sum(x[:, :-2], axis=0)),
+        # Indexing after the axis of reduction
+        (lambda x: pt_sum(x, axis=0)[:, 1:], lambda x: pt_sum(x[:, :, 1:], axis=0)),
+        # Indexing before and after the axis of reduction
+        (lambda x: pt_sum(x, axis=1)[-2, 1:], lambda x: pt_sum(x[-2, :, 1:], axis=0)),
+        (lambda x: pt_sum(x, axis=1)[1:, -2], lambda x: pt_sum(x[1:, :, -2], axis=1)),
+    ],
+)
+def test_local_subtensor_of_reduce(original_fn, expected_fn):
+    rng = np.random.default_rng(245)
+    x = pt.tensor("x", shape=(5, 3, 2))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
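
The axis bookkeeping the rewrite performs (an integer index before the reduction axis drops a dimension, so the lifted reduction must use a smaller axis) can be sanity-checked with plain NumPy, independently of PyTensor. A small sketch of the `[-2, 1:]` case from the parametrization above, assuming nothing beyond NumPy itself:

import numpy as np

rng = np.random.default_rng(245)
x = rng.normal(size=(5, 3, 2))

# Original form: reduce over axis 1, then index. The integer index -2
# drops the leading dimension, so in the lifted form the reduction
# axis shifts from 1 down to 0, exactly as the rewrite computes.
lhs = np.sum(x, axis=1)[-2, 1:]
rhs = np.sum(x[-2, :, 1:], axis=0)

np.testing.assert_allclose(lhs, rhs)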
