Commit 844ae15

Lift Subtensor over Softmax
1 parent ffcfa7d commit 844ae15

2 files changed: +136 −1 lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 97 additions & 1 deletion
@@ -1,7 +1,10 @@
 from collections.abc import Iterable, Sequence

 import numpy as np
-from numpy.core.numeric import normalize_axis_tuple  # type: ignore
+from numpy.core.numeric import (  # type: ignore
+    normalize_axis_index,
+    normalize_axis_tuple,
+)

 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
@@ -34,6 +37,7 @@
     specify_shape,
     unbroadcast,
 )
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -53,6 +57,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))


+def _ndim_dropped_left_of_axis_by_basic_index(
+    idxs: Sequence[slice | int], axis: int
+) -> int:
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: Sequence[slice | int], axis: int | Sequence[int]
+) -> bool:
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -243,6 +261,84 @@ def local_subtensor_of_reduce(fgraph, node):
         return [out]


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis
+        # Unlike CAReduce it does not normalize negative indices
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If there are more dimensions being indexed, we can split them
+        # and lift the non-axis indexes while keeping the axis index
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
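Note (not part of the diff): a minimal NumPy sketch of the identities this rewrite relies on, using a hand-rolled softmax so nothing beyond numpy is assumed. It checks that indexing a non-softmax dimension commutes with the softmax, and that an integer index that drops a dimension to the left of the softmax axis shifts that axis by one, which is what the `axis -= len(...)` adjustment above implements.

import numpy as np


def softmax(a, axis):
    # Numerically stable softmax along `axis` (stand-in for pytensor's softmax).
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
x2 = rng.normal(size=(5, 3))
x3 = rng.normal(size=(2, 3, 4))

# Indexing a non-softmax dimension commutes with the softmax:
# softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
np.testing.assert_allclose(softmax(x2, axis=1)[0], softmax(x2[0], axis=0))

# An integer index left of the softmax axis drops one dimension, so the
# lifted softmax uses axis 2 - 1 = 1 (the axis adjustment in the rewrite):
np.testing.assert_allclose(softmax(x3, axis=2)[0], softmax(x3[0], axis=1))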

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 39 additions & 0 deletions
@@ -51,6 +51,7 @@
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
+from pytensor.tensor.special import softmax
 from pytensor.tensor.subtensor import Subtensor


@@ -213,6 +214,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
     )


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Lift single index that does not overlap with axis of softmax
+        (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)),
+        (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)),
+        (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)),
+        (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)),
+        # Do nothing to single index over axis of softmax
+        (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]),
+        (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]),
+        # Split indexing on axis of softmax
+        (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]),
+        (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]),
+        (
+            lambda x: softmax(x, axis=0)[0, :5:2],
+            lambda x: softmax(x[:, :5:2], axis=0)[0],
+        ),
+        (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]),
+    ],
+)
+def test_local_subtensor_of_softmax(original_fn, expected_fn):
+    rng = np.random.default_rng(230)
+    x = pt.matrix("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
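Note (not part of the diff): the "split" cases in the parametrization above can also be checked by hand with NumPy; a hand-rolled softmax is used so the snippet is self-contained. The index on the softmax axis stays outside the lifted softmax, while the remaining indices move inside.

import numpy as np


def softmax(a, axis):
    # Numerically stable softmax along `axis` (stand-in for pytensor's softmax).
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(230)
x = rng.normal(size=(5, 3))

# softmax(x, axis=0)[1:, 0] -> softmax(x[:, 0], axis=0)[1:]
np.testing.assert_allclose(softmax(x, axis=0)[1:, 0], softmax(x[:, 0], axis=0)[1:])

# softmax(x, axis=1)[0, :5:2] -> softmax(x[0], axis=0)[:5:2]
np.testing.assert_allclose(softmax(x, axis=1)[0, :5:2], softmax(x[0], axis=0)[:5:2])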
