Commit ed06277

Lift Subtensor over Softmax

1 parent 0f3edf9 · commit ed06277

2 files changed: +134 additions, -2 deletions
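
As a quick plain-NumPy illustration (not part of the commit; the local softmax helper below merely stands in for pytensor.tensor.special.softmax), these are the two identities the new rewrite relies on: an index on a non-softmax dimension commutes with the softmax, and when the softmax axis itself is indexed, only the remaining indices are lifted.

import numpy as np

def softmax(x, axis):
    # reference softmax, numerically stabilised
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))

# Indexing a non-softmax dimension commutes with the softmax:
# softmax(x, axis=1)[0] == softmax(x[0], axis=0)
np.testing.assert_allclose(softmax(x, axis=1)[0], softmax(x[0], axis=0))

# Split case: the slice on the non-softmax axis is lifted into the input,
# while the integer index on the softmax axis stays outside:
# softmax(x, axis=0)[1:, 0] == softmax(x[:, 0], axis=0)[1:]
np.testing.assert_allclose(softmax(x, axis=0)[1:, 0], softmax(x[:, 0], axis=0)[1:])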

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 95 additions & 2 deletions
@@ -1,7 +1,7 @@
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence

 import numpy as np
-from numpy.core.numeric import normalize_axis_tuple
+from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
@@ -34,6 +34,7 @@
     specify_shape,
     unbroadcast,
 )
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -48,9 +49,23 @@


 def _dims_dropped_by_basic_index(idxs) -> tuple[int, ...]:
+    # Inputs can be slice or integer indexes
+    # Slices keep the dimensions, integers collapse them
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))


+def _ndim_dropped_left_of_axis_by_basic_index(idxs, axis: int) -> int:
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: tuple[Variable], axis: int | Sequence[int]
+) -> bool:
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -241,6 +256,84 @@ def local_subtensor_of_reduce(fgraph, node):
     return [out]


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it:
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[:, 0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis.
+        # Unlike CAReduce, it does not normalize negative indices.
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If there are more dimensions being indexed, we can split them
+        # and lift the non-axis indexes while keeping the axis index
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
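
The helpers added above, together with the final `axis -=` adjustment, encode one piece of bookkeeping: every integer index to the left of the softmax axis collapses a dimension, so the axis of the lifted softmax shifts down accordingly. A standalone restatement with hypothetical example values (not code from the commit):

idx_tuple = (0, slice(1, None))   # e.g. x[0, 1:] applied to a 3-D input
axis = 2                          # softmax axis before indexing

# Integer indices left of `axis` each drop a dimension; slices keep theirs.
dims_dropped_left = sum(1 for idx in idx_tuple[:axis] if not isinstance(idx, slice))
new_axis = axis - dims_dropped_left

# softmax(x, axis=2)[0, 1:] can therefore be rewritten as softmax(x[0, 1:], axis=1)
assert new_axis == 1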

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 39 additions & 0 deletions
@@ -51,6 +51,7 @@
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
+from pytensor.tensor.special import softmax
 from pytensor.tensor.subtensor import Subtensor


@@ -213,6 +214,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
     )


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Lift single index that does not overlap with axis of softmax
+        (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)),
+        (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)),
+        (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)),
+        (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)),
+        # Do nothing to single index over axis of softmax
+        (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]),
+        (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]),
+        # Split indexing on axis of softmax
+        (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]),
+        (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]),
+        (
+            lambda x: softmax(x, axis=0)[0, :5:2],
+            lambda x: softmax(x[:, :5:2], axis=0)[0],
+        ),
+        (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]),
+    ],
+)
+def test_local_subtensor_of_softmax(original_fn, expected_fn):
+    rng = np.random.default_rng(230)
+    x = pt.matrix("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
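
Beyond the parametrized assertions above, one quick way to see the rewrite in action is to run the graph rewriter directly and print the result; a minimal sketch, assuming a PyTensor installation that includes this commit (import paths as used elsewhere in the code base):

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.printing import debugprint
from pytensor.tensor.special import softmax

x = pt.matrix("x", shape=(5, 3))
before = softmax(x, axis=1)[0]   # Subtensor sits on top of the Softmax
after = rewrite_graph(before)    # canonicalize/specialize passes apply the lift

# Expected to print a graph equivalent to softmax(x[0], axis=0): the row is
# selected before the softmax, so the softmax runs on a length-3 vector
# instead of the full 5x3 matrix.
debugprint(after)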
