Skip to content

Commit 605e733

Browse files
committed
Lift Subtensor over transpose
1 parent f5a13f2 commit 605e733

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterable
1+
from collections.abc import Iterable, Sequence
22

33
import numpy as np
44

@@ -17,12 +17,14 @@
1717
)
1818
from pytensor.tensor.elemwise import DimShuffle, Elemwise
1919
from pytensor.tensor.exceptions import NotScalarConstantError
20+
from pytensor.tensor.extra_ops import squeeze
2021
from pytensor.tensor.math import Dot, ceil_intdiv, dot
2122
from pytensor.tensor.rewriting.basic import (
2223
register_canonicalize,
2324
register_specialize,
2425
register_stabilize,
2526
)
27+
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
2628
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
2729
from pytensor.tensor.shape import (
2830
Shape,
@@ -44,6 +46,12 @@
4446
from pytensor.tensor.type_other import SliceType
4547

4648

49+
def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
50+
# Inputs can be slice or integer indexes
51+
# Slices keep the dimensions, integers collapse them
52+
return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
53+
54+
4755
@register_canonicalize
4856
@register_stabilize
4957
@register_specialize
@@ -280,6 +288,55 @@ def local_subtensor_of_expand_dims(fgraph, node):
280288
return [out]
281289

282290

291+
@register_canonicalize
292+
@register_specialize
293+
@node_rewriter([Subtensor])
294+
def local_subtensor_of_transpose(fgraph, node):
295+
"""Lift a Subtensor through a DimShuffle that only transposes.
296+
297+
transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2))
298+
"""
299+
ds, *idx = node.inputs
300+
301+
if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
302+
return None
303+
304+
ds_op = ds.owner.op
305+
if not ds_op.is_transpose:
306+
return None
307+
308+
transposition = ds_op.transposition
309+
[x] = ds.owner.inputs
310+
311+
idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
312+
313+
# Apply the transposition to the indexes
314+
ndim = x.type.ndim
315+
n_implicit_idxs = ndim - len(idx_tuple)
316+
idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs
317+
new_idxs = [idx_tuple[transposition.index(i)] for i in range(ndim)]
318+
new_x = x[tuple(new_idxs)]
319+
320+
# Reintroduce any dims dropped by indexing so the original transpose still works
321+
dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs)
322+
if dims_dropped_by_new_idx:
323+
new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx)
324+
325+
# Apply the transpose
326+
new_out = ds_op(new_x)
327+
328+
# Squeeze dims again now that the transpose is done
329+
if dims_dropped_by_new_idx:
330+
dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple)
331+
new_out = squeeze(new_out, axis=dims_dropped_by_original_idx)
332+
333+
# Cleanup consecutive expand_dims / transpose / squeeze (if any)
334+
if dims_dropped_by_new_idx:
335+
[new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner)
336+
337+
return [new_out]
338+
339+
283340
@register_infer_shape
284341
@register_useless
285342
@register_canonicalize

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
316316

317317
out = original_fn(x)
318318
expected_opt_out = expected_fn(x)
319-
opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
319+
opt_out = rewrite_graph(out)
320320
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
321321
[opt_out, expected_opt_out], print_type=True
322322
)
@@ -326,6 +326,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
326326
)
327327

328328

329+
@pytest.mark.parametrize(
330+
"original_fn, expected_fn",
331+
[
332+
(lambda x: x.transpose(2, 1, 0)[0], lambda x: x[:, :, 0].transpose(1, 0)),
333+
(lambda x: x.transpose(2, 1, 0)[:, :, 1:], lambda x: x[1:].transpose(2, 1, 0)),
334+
(
335+
lambda x: x.transpose(2, 1, 0)[0, :1, 1:],
336+
lambda x: x[1:, :1, 0].transpose(1, 0),
337+
),
338+
(lambda x: x.transpose(2, 1, 0)[0, :1, 1], lambda x: x[1, :1, 0]),
339+
],
340+
)
341+
def test_local_subtensor_of_transpose(original_fn, expected_fn):
342+
rng = np.random.default_rng(232)
343+
x = tensor("x", shape=(7, 5, 3))
344+
x_test = rng.normal(size=x.type.shape)
345+
346+
out = original_fn(x)
347+
expected_opt_out = expected_fn(x)
348+
opt_out = rewrite_graph(out)
349+
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
350+
[expected_opt_out, opt_out], print_type=True
351+
)
352+
np.testing.assert_allclose(
353+
opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
354+
out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
355+
)
356+
357+
329358
def test_local_subtensor_of_alloc():
330359
# DebugMode should detect if something goes wrong.
331360
# test shape combination of odd and even shape.

0 commit comments

Comments
 (0)