Commit c20a47b

Lift Subtensor over Softmax
1 parent 0960e05

2 files changed (+133, −1 lines)

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 94 additions & 1 deletion
@@ -5,7 +5,7 @@
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
-from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
@@ -34,6 +34,7 @@
     specify_shape,
     unbroadcast,
 )
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -53,6 +54,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))


+def _ndim_dropped_left_of_axis_by_basic_index(
+    idxs: Sequence[slice | int], axis: int
+) -> int:
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: Sequence[slice | int], axis: int | Sequence[int]
+) -> bool:
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -243,6 +258,84 @@ def local_subtensor_of_reduce(fgraph, node):
     return [out]


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it:
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis
+        # Unlike CAReduce it does not normalize negative indices
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If there are more dimensions being indexed, we can split them
+        # and lift the non-axis indexes while keeping the axis index
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
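
Note: the docstring examples can be exercised end to end. A minimal sketch (not part of the commit), assuming `rewrite_graph` is importable from `pytensor.graph` and that its default passes include the canonicalize database this rewrite is registered in (the same call the tests below rely on); the variable names are illustrative:

import numpy as np

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.tensor.special import softmax

# softmax normalizes each row (axis=1); indexing a single row afterwards
# does not touch the softmax axis, so the Subtensor can be lifted.
x = pt.matrix("x", shape=(5, 3))
out = softmax(x, axis=1)[0]

# After rewriting, the graph should compute softmax(x[0], axis=0) instead,
# i.e. the softmax of a single row rather than of the full matrix.
opt_out = rewrite_graph(out)

# Both graphs evaluate to the same values.
x_test = np.random.default_rng(0).normal(size=(5, 3))
np.testing.assert_allclose(out.eval({x: x_test}), opt_out.eval({x: x_test}))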

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 39 additions & 0 deletions
@@ -46,6 +46,7 @@
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
+from pytensor.tensor.special import softmax
 from pytensor.tensor.subtensor import Subtensor

@@ -212,6 +213,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
     )


+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Lift single index that does not overlap with axis of softmax
+        (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)),
+        (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)),
+        (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)),
+        (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)),
+        # Do nothing to single index over axis of softmax
+        (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]),
+        (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]),
+        # Split indexing on axis of softmax
+        (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]),
+        (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]),
+        (
+            lambda x: softmax(x, axis=0)[0, :5:2],
+            lambda x: softmax(x[:, :5:2], axis=0)[0],
+        ),
+        (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]),
+    ],
+)
+def test_local_subtensor_of_softmax(original_fn, expected_fn):
+    rng = np.random.default_rng(230)
+    x = pt.matrix("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 def test_local_subtensor_of_unbroadcast():
     # Test that Subtensor(Unbroadcast(x)) gets optimized into
     # Unbroadcast(Subtensor(x)).
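
Note: for intuition about why the rewritten graphs in these tests are value-preserving, the underlying identity can be checked with plain NumPy, independently of PyTensor. A small sketch (not part of the commit); `np_softmax` is a throwaway reference implementation defined only for this check:

import numpy as np


def np_softmax(a, axis):
    # Reference softmax, used only for this sanity check.
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))

# Indexing a non-softmax axis commutes with the softmax:
# softmax(x, axis=1)[0] == softmax(x[0], axis=0)
np.testing.assert_allclose(np_softmax(x, axis=1)[0], np_softmax(x[0], axis=0))

# The split case, specialized to 2D: the slice on the non-softmax axis is
# lifted, while the index on the softmax axis stays outside.
# softmax(x, axis=0)[1:, 0] == softmax(x[:, 0], axis=0)[1:]
np.testing.assert_allclose(
    np_softmax(x, axis=0)[1:, 0], np_softmax(x[:, 0], axis=0)[1:]
)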
