
Commit eb11f0f

Lift Subtensor over Softmax

1 parent: b673bc9
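Not part of the commit: a minimal NumPy sketch of the identity the rewrite exploits. Indexing a dimension other than the softmax axis commutes with the softmax, while indexing along the softmax axis does not (it changes the normalizing sum); np_softmax below is a local stand-in for the PyTensor op.

import numpy as np

def np_softmax(a, axis):
    # Reference softmax, numerically stabilized
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.default_rng(0).normal(size=(5, 3))

# Row 0 of a row-wise softmax equals the softmax of row 0 -> safe to lift the index
np.testing.assert_allclose(np_softmax(x, axis=1)[0], np_softmax(x[0], axis=0))

# Slicing along the softmax axis changes the normalizer -> not safe to lift
assert not np.allclose(np_softmax(x, axis=1)[:, 1:], np_softmax(x[:, 1:], axis=1))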

File tree

    pytensor/tensor/rewriting/subtensor_lift.py
    tests/tensor/rewriting/test_subtensor_lift.py

2 files changed: +134 -2 lines

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 95 additions & 2 deletions
@@ -1,7 +1,7 @@
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 
 import numpy as np
-from numpy.core.numeric import normalize_axis_tuple
+from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
@@ -29,6 +29,7 @@
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
 from pytensor.tensor.shape import Shape, SpecifyShape, Unbroadcast, unbroadcast
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -42,9 +43,23 @@
 
 
 def _dims_dropped_by_basic_index(idxs) -> tuple[int, ...]:
+    # Inputs can be slice or integer indexes
+    # Slices keep the dimensions, integers collapse them
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
 
 
+def _ndim_dropped_left_of_axis_by_basic_index(idxs, axis: int) -> int:
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: tuple[Variable], axis: int | Sequence[int]
+) -> bool:
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
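For orientation (not part of the commit), a tiny sketch of what these helpers compute for the basic index tuple of x[0, :, 1:] with the softmax axis at 1, using a plain slice check in place of PyTensor's is_full_slice:

idxs = (0, slice(None), slice(1, None))  # basic indices of x[0, :, 1:]
axis = 1

# Dimensions collapsed by integer indexing (what _dims_dropped_by_basic_index returns)
dims_dropped = tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))

# How many dimensions vanish left of the softmax axis, i.e. how far the axis shifts
ndim_dropped_left = len([idx for idx in idxs[:axis] if not isinstance(idx, slice)])

# Whether the softmax axis itself is touched by a non-trivial index
axis_is_indexed = idxs[axis] != slice(None)

print(dims_dropped)       # (0,)
print(ndim_dropped_left)  # 1     -> the softmax axis moves from 1 to 0 after indexing
print(axis_is_indexed)    # False -> the whole index can be lifted below the softmax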
@@ -235,6 +250,84 @@ def local_subtensor_of_reduce(fgraph, node):
     return [out]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[:, 0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis
+        # Unlike CAReduce it does not normalize negative indices
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If there are more dimensions being indexed, we can split them
+        # And lift the non-axis indexes while keeping the axis index
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
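A rough usage sketch (not from the commit) of the rewrite in action, assuming a PyTensor build that includes this change; rewrite_graph and debugprint are the same utilities the tests below rely on:

import numpy as np
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.printing import debugprint
from pytensor.tensor.special import softmax

x = pt.matrix("x", shape=(5, 3))

# Index a single row of a row-wise softmax
out = softmax(x, axis=1)[0]

# After rewriting, the Subtensor should sit below the Softmax,
# so only the selected row is exponentiated and normalized
opt_out = rewrite_graph(out)
debugprint(opt_out)

# Both graphs compute the same values
x_test = np.random.default_rng(42).normal(size=(5, 3))
np.testing.assert_allclose(opt_out.eval({x: x_test}), out.eval({x: x_test}))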

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 39 additions & 0 deletions
@@ -51,6 +51,7 @@
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
+from pytensor.tensor.special import softmax
 from pytensor.tensor.subtensor import Subtensor
 
 
@@ -213,6 +214,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
     )
 
 
+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Lift single index that does not overlap with axis of softmax
+        (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)),
+        (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)),
+        (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)),
+        (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)),
+        # Do nothing to single index over axis of softmax
+        (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]),
+        (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]),
+        # Split indexing on axis of softmax
+        (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]),
+        (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]),
+        (
+            lambda x: softmax(x, axis=0)[0, :5:2],
+            lambda x: softmax(x[:, :5:2], axis=0)[0],
+        ),
+        (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]),
+    ],
+)
+def test_local_subtensor_of_softmax(original_fn, expected_fn):
+    rng = np.random.default_rng(230)
+    x = pt.matrix("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).

0 commit comments
