Skip to content

Commit 5c97e9f

Browse files
committed
Lift Subtensor over transpose
1 parent 2ad449b commit 5c97e9f

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterable
1+
from collections.abc import Iterable, Sequence
22

33
import numpy as np
44

@@ -17,12 +17,14 @@
1717
)
1818
from pytensor.tensor.elemwise import DimShuffle, Elemwise
1919
from pytensor.tensor.exceptions import NotScalarConstantError
20+
from pytensor.tensor.extra_ops import squeeze
2021
from pytensor.tensor.math import Dot, ceil_intdiv, dot
2122
from pytensor.tensor.rewriting.basic import (
2223
register_canonicalize,
2324
register_specialize,
2425
register_stabilize,
2526
)
27+
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
2628
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
2729
from pytensor.tensor.shape import (
2830
Shape,
@@ -44,6 +46,12 @@
4446
from pytensor.tensor.type_other import SliceType
4547

4648

49+
def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
50+
# Inputs can be slice or integer indexes
51+
# Slices keep the dimensions, integers collapse them
52+
return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
53+
54+
4755
@register_canonicalize
4856
@register_stabilize
4957
@register_specialize
@@ -280,6 +288,54 @@ def local_subtensor_of_expand_dims(fgraph, node):
280288
return [out]
281289

282290

291+
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_transpose(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only transposes.

    transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2))

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph being rewritten (forwarded to `local_dimshuffle_lift`).
    node : Apply
        A `Subtensor` node; the rewrite applies when its indexed input is a
        pure-transpose `DimShuffle`.

    Returns
    -------
    list of Variable or None
        The replacement output, or ``None`` when the rewrite does not apply.
    """
    ds, *idx = node.inputs

    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None

    ds_op = ds.owner.op
    if not ds_op.is_transpose:
        return None

    transposition = ds_op.transposition
    [x] = ds.owner.inputs

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    # Pad with full slices for trailing dims the indexing left implicit
    n_implicit_idxs = x.type.ndim - len(idx_tuple)
    idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs

    # Route each index to the input dimension it actually selects from.
    # Output dim `i` of the transpose is input dim `transposition[i]`, so the
    # index applied to output dim `i` must go to input dim `transposition[i]`.
    # This is the *inverse* permutation of `transposition`; iterating
    # `transposition` forward (``[idx_tuple[i] for i in transposition]``) is
    # only correct when the permutation is its own inverse, e.g. (1, 0, 2) or
    # (2, 1, 0).
    new_idxs: list[slice | int] = [slice(None)] * x.type.ndim
    for out_dim, x_dim in enumerate(transposition):
        new_idxs[x_dim] = idx_tuple[out_dim]
    new_x = x[tuple(new_idxs)]

    # Reintroduce any dims dropped by indexing so the original transpose still works
    dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs)
    if dims_dropped_by_new_idx:
        new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx)

    # Apply the transpose
    new_out = ds_op(new_x)

    # Squeeze dims again now that the transpose is done; dropped dims sit at
    # the positions the *original* indices occupied in the transposed output
    if dims_dropped_by_new_idx:
        dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple)
        new_out = squeeze(new_out, axis=dims_dropped_by_original_idx)

    # Cleanup consecutive expand_dims / transpose / squeeze (if any)
    if dims_dropped_by_new_idx:
        [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner)

    return [new_out]
337+
338+
283339
@register_infer_shape
284340
@register_useless
285341
@register_canonicalize

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
310310

311311
out = original_fn(x)
312312
expected_opt_out = expected_fn(x)
313-
opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
313+
opt_out = rewrite_graph(out)
314314
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
315315
[opt_out, expected_opt_out], print_type=True
316316
)
@@ -320,6 +320,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
320320
)
321321

322322

323+
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        (lambda x: x.transpose(2, 1, 0)[0], lambda x: x[:, :, 0].transpose(1, 0)),
        (lambda x: x.transpose(2, 1, 0)[:, :, 1:], lambda x: x[1:].transpose(2, 1, 0)),
        (
            lambda x: x.transpose(2, 1, 0)[0, :1, 1:],
            lambda x: x[1:, :1, 0].transpose(1, 0),
        ),
        (lambda x: x.transpose(2, 1, 0)[0, :1, 1], lambda x: x[1, :1, 0]),
        # Non-involutive transposition: index on output dim 0 selects from
        # input dim 1, so the forward-permutation mistake (routing it to input
        # dim 2) is caught here but not by the self-inverse cases above.
        (lambda x: x.transpose(1, 2, 0)[0], lambda x: x[:, 0].transpose(1, 0)),
    ],
)
def test_local_subtensor_of_transpose(original_fn, expected_fn):
    """Subtensor should be lifted through a pure-transpose DimShuffle."""
    rng = np.random.default_rng(232)
    x = tensor("x", shape=(7, 5, 3))
    x_test = rng.normal(size=x.type.shape)

    out = original_fn(x)
    expected_opt_out = expected_fn(x)
    opt_out = rewrite_graph(out)
    # Structural check: the rewritten graph must match the hand-lifted form
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    # Numerical check: rewrite preserves values (evaluated without rewrites)
    np.testing.assert_allclose(
        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )
350+
351+
323352
def test_local_subtensor_of_alloc():
324353
# DebugMode should detect if something goes wrong.
325354
# test shape combination of odd and event shape.

0 commit comments

Comments
 (0)