
Commit ffcfa7d

Lift Subtensor over CAReduce

1 parent e1ee3a2

2 files changed: +94 -1 lines

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 59 additions & 1 deletion
@@ -1,6 +1,7 @@
 from collections.abc import Iterable, Sequence
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple  # type: ignore
 
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
@@ -15,7 +16,7 @@
     get_underlying_scalar_constant_value,
     register_infer_shape,
 )
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import Dot, ceil_intdiv, dot
@@ -185,6 +186,63 @@ def local_subtensor_of_elemwise(fgraph, node):
     return [new_out]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_reduce(fgraph, node):
+    """Lift a Subtensor through a CAReduce Op.
+
+    For now the rewrite is restricted to a single axis of reduction, for simplicity.
+
+    sum(x, axis=1)[0] -> sum(x[0], axis=0)
+    sum(x, axis=1)[1:] -> sum(x[1:], axis=1)
+    sum(x, axis=0)[0] -> sum(x[:, 0], axis=0)
+    sum(x, axis=0)[1:] -> sum(x[:, 1:], axis=0)
+
+    """
+    red, *idx = node.inputs
+
+    if not (red.owner and isinstance(red.owner.op, CAReduce)):
+        return None
+
+    if len(fgraph.clients[red]) > 1:
+        # Don't apply the rewrite if another node requires the full reduction
+        return None
+
+    [x] = red.owner.inputs
+    axis = red.owner.op.axis
+
+    if axis is None:
+        axis = tuple(range(x.type.ndim))
+
+    # TODO: Allow reduction across multiple axes
+    if len(axis) != 1:
+        return None
+
+    [axis] = normalize_axis_tuple(axis, x.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    # Index the input of the reduction.
+    new_idxs = list(idx_tuple)
+    if axis < len(idx_tuple):
+        # Indices at or beyond the axis of reduction must skip over it,
+        # so insert a full slice at the axis position.
+        new_idxs.insert(axis, slice(None))
+    x_sub = x[tuple(new_idxs)]
+
+    [old_out] = node.outputs
+    copy_stack_trace(old_out, x_sub)
+
+    # Adjust the axis of reduction when indexing drops dimensions
+    # (integer indexing, as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    # Apply the reduction to the indexed input
+    out = type(red.owner.op)(axis=axis)(x_sub)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
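The axis bookkeeping in local_subtensor_of_reduce is easy to sanity-check outside of PyTensor. The sketch below replays the same two steps in plain NumPy: insert a full slice at the reduced axis before indexing, then shift the axis left by one for each integer index that precedes it. The helper name lift_index_over_sum is hypothetical, for illustration only, and not part of this commit.

import numpy as np

def lift_index_over_sum(x, idx_tuple, axis):
    # Hypothetical standalone mirror of the rewrite's bookkeeping:
    # compute sum(x, axis=axis)[idx_tuple] by indexing first.
    new_idxs = list(idx_tuple)
    if axis < len(idx_tuple):
        # Indices at or beyond the reduced axis must skip over it.
        new_idxs.insert(axis, slice(None))
    x_sub = x[tuple(new_idxs)]
    # Integer indices before the axis drop dimensions, shifting the axis left.
    axis -= sum(not isinstance(i, slice) for i in idx_tuple[:axis])
    return x_sub.sum(axis=axis)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 2))

# sum(x, axis=1)[-2, 1:] matches sum(x[-2, :, 1:], axis=0)
np.testing.assert_allclose(
    lift_index_over_sum(x, (-2, slice(1, None)), axis=1),
    x.sum(axis=1)[-2, 1:],
)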

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 35 additions & 0 deletions
@@ -44,6 +44,7 @@
 )
 from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
     local_subtensor_of_elemwise,
@@ -178,6 +179,40 @@ def test_local_subtensor_of_elemwise_multiple_clients():
     assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
 
 
+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Indexing before axis of reduction
+        (lambda x: pt_sum(x, axis=2)[0], lambda x: pt_sum(x[0], axis=1)),
+        (lambda x: pt_sum(x, axis=2)[0, 1], lambda x: pt_sum(x[0, 1], axis=None)),
+        (lambda x: pt_sum(x, axis=2)[1:], lambda x: pt_sum(x[1:], axis=2)),
+        # Indexing "at" axis of reduction
+        (lambda x: pt_sum(x, axis=0)[2], lambda x: pt_sum(x[:, 2], axis=0)),
+        (lambda x: pt_sum(x, axis=0)[:-2], lambda x: pt_sum(x[:, :-2], axis=0)),
+        # Indexing after axis of reduction
+        (lambda x: pt_sum(x, axis=0)[:, 1:], lambda x: pt_sum(x[:, :, 1:], axis=0)),
+        # Indexing before and after axis of reduction
+        (lambda x: pt_sum(x, axis=1)[-2, 1:], lambda x: pt_sum(x[-2, :, 1:], axis=0)),
+        (lambda x: pt_sum(x, axis=1)[1:, -2], lambda x: pt_sum(x[1:, :, -2], axis=1)),
+    ],
+)
+def test_local_subtensor_of_reduce(original_fn, expected_fn):
+    rng = np.random.default_rng(245)
+    x = pt.tensor("x", shape=(5, 3, 2))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
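For an end-to-end look at the new rewrite in action, here is a minimal sketch, assuming the pt.tensor constructor used by the test above plus the standard rewrite_graph and debugprint helpers; the exact import paths may differ between PyTensor versions.

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint

x = pt.tensor("x", shape=(5, 3, 2))
out = pt.sum(x, axis=1)[-2, 1:]  # reduce first, index after

# After rewriting, the graph should be equivalent to pt.sum(x[-2, :, 1:], axis=0):
# the reduction now runs on the already-indexed, smaller input.
opt_out = rewrite_graph(out)
debugprint(opt_out, print_type=True)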
