
Commit 0960e05

Lift Subtensor over CAReduce
1 parent: b4f178a
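
As a quick illustration of the new rewrite, here is a minimal sketch (an illustrative example, not part of the commit) of how the lift can be observed from user code. It assumes the rewrite is picked up by the default rewrite passes, since it is registered as both a canonicalization and a specialization, and that rewrite_graph is importable from pytensor.graph as it is elsewhere in the test suite:

import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.tensor("x", shape=(5, 3, 2))

# Index into the result of a reduction over axis 1 ...
out = pt.sum(x, axis=1)[0]

# ... and let the default rewrites lift the Subtensor above the Sum.
# After this commit the rewritten graph should be equivalent to pt.sum(x[0], axis=0).
opt_out = rewrite_graph(out)
pytensor.dprint(opt_out)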

2 files changed, +94 -1 lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 59 additions & 1 deletion
@@ -5,6 +5,7 @@
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
@@ -15,7 +16,7 @@
     get_underlying_scalar_constant_value,
     register_infer_shape,
 )
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import Dot, ceil_intdiv, dot
@@ -185,6 +186,63 @@ def local_subtensor_of_elemwise(fgraph, node):
     return [new_out]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_reduce(fgraph, node):
+    """Lift a Subtensor through a CAReduce Op.
+
+    For now the rewrite is restricted to a single axis of reduction, for simplicity.
+
+    sum(x, axis=1)[0] -> sum(x[0], axis=0)
+    sum(x, axis=1)[1:] -> sum(x[1:], axis=1)
+    sum(x, axis=0)[0] -> sum(x[:, 0], axis=0)
+    sum(x, axis=0)[1:] -> sum(x[:, 1:], axis=0)
+
+    """
+    red, *idx = node.inputs
+
+    if not (red.owner and isinstance(red.owner.op, CAReduce)):
+        return None
+
+    if len(fgraph.clients[red]) > 1:
+        # Don't apply the rewrite if another node requires the full reduction
+        return None
+
+    [x] = red.owner.inputs
+    axis = red.owner.op.axis
+
+    if axis is None:
+        axis = tuple(range(x.type.ndim))
+
+    # TODO: Allow reduction across multiple axes
+    if len(axis) != 1:
+        return None
+
+    [axis] = normalize_axis_tuple(axis, x.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    # Index the input of the reduction.
+    new_idxs = list(idx_tuple)
+    if axis < len(idx_tuple):
+        # When indices reach or extend past the axis of reduction,
+        # insert a full slice at that axis so the later indices still line up.
+        new_idxs.insert(axis, slice(None))
+    x_sub = x[tuple(new_idxs)]
+
+    [old_out] = node.outputs
+    copy_stack_trace(old_out, x_sub)
+
+    # Adjust the axis of reduction when indexing drops dimensions
+    # (integer indexing as opposed to slice indexing).
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    # Apply the reduction to the indexed input
+    out = type(red.owner.op)(axis=axis)(x_sub)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
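
The axis bookkeeping in local_subtensor_of_reduce (inserting a full slice at the reduced axis, then decrementing the axis once for every integer index that precedes it) mirrors a plain NumPy identity. The following is a small sketch of that identity with an arbitrary example array; it is only meant to make the adjustment concrete and is not part of the commit:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 2))

# The integer index 0 drops the leading dimension, so the lifted reduction
# moves from axis 1 of `x` to axis 0 of the indexed input.
np.testing.assert_allclose(x.sum(axis=1)[0, 1:], x[0, :, 1:].sum(axis=0))

# With only slice indices no dimension is dropped, so the axis stays at 1.
np.testing.assert_allclose(x.sum(axis=1)[1:, :1], x[1:, :, :1].sum(axis=1))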

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 35 additions & 0 deletions
@@ -39,6 +39,7 @@
 )
 from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
     local_subtensor_of_elemwise,
@@ -177,6 +178,40 @@ def test_local_subtensor_of_elemwise_multiple_clients(self):
         assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
 
 
+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Indexing before axis of reduction
+        (lambda x: pt_sum(x, axis=2)[0], lambda x: pt_sum(x[0], axis=1)),
+        (lambda x: pt_sum(x, axis=2)[0, 1], lambda x: pt_sum(x[0, 1], axis=None)),
+        (lambda x: pt_sum(x, axis=2)[1:], lambda x: pt_sum(x[1:], axis=2)),
+        # Indexing "at" axis of reduction
+        (lambda x: pt_sum(x, axis=0)[2], lambda x: pt_sum(x[:, 2], axis=0)),
+        (lambda x: pt_sum(x, axis=0)[:-2], lambda x: pt_sum(x[:, :-2], axis=0)),
+        # Indexing after axis of reduction
+        (lambda x: pt_sum(x, axis=0)[:, 1:], lambda x: pt_sum(x[:, :, 1:], axis=0)),
+        # Indexing before and after axis of reduction
+        (lambda x: pt_sum(x, axis=1)[-2, 1:], lambda x: pt_sum(x[-2, :, 1:], axis=0)),
+        (lambda x: pt_sum(x, axis=1)[1:, -2], lambda x: pt_sum(x[1:, :, -2], axis=1)),
+    ],
+)
+def test_local_subtensor_of_reduce(original_fn, expected_fn):
+    rng = np.random.default_rng(245)
+    x = pt.tensor("x", shape=(5, 3, 2))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
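
One behavioural detail of the rewrite worth noting: it deliberately bails out when the reduction has more than one client, so the full Sum is never recomputed. Below is a minimal sketch of such a graph, where the lift is therefore not expected to trigger; the variables and expression are illustrative only, not taken from the commit:

import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.tensor("x", shape=(5, 3, 2))
s = pt.sum(x, axis=1)

# `s` has two clients (the Subtensor and the exp), so the new rewrite
# leaves the indexed reduction alone rather than duplicating the Sum.
out = s[0] + pt.exp(s).sum()
pytensor.dprint(rewrite_graph(out))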
