Skip to content

Commit f8cfe6a

Browse files
committed
Lift Subtensor over transpose
1 parent 5fc9e06 commit f8cfe6a

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
dot,
1515
get_underlying_scalar_constant_value,
1616
specify_shape,
17+
squeeze,
1718
)
1819
from pytensor.tensor.basic import Alloc, MakeVector, expand_dims, register_infer_shape
1920
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -24,6 +25,7 @@
2425
register_specialize,
2526
register_stabilize,
2627
)
28+
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
2729
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
2830
from pytensor.tensor.shape import Shape, SpecifyShape, Unbroadcast, unbroadcast
2931
from pytensor.tensor.subtensor import (
@@ -38,6 +40,10 @@
3840
from pytensor.tensor.type_other import SliceType
3941

4042

43+
def _dims_dropped_by_basic_index(idxs) -> tuple[int, ...]:
44+
return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
45+
46+
4147
@register_canonicalize
4248
@register_stabilize
4349
@register_specialize
@@ -274,6 +280,54 @@ def local_subtensor_of_expand_dims(fgraph, node):
274280
return [out]
275281

276282

283+
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_transpose(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only transposes.

    transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2))
    """
    ds, *idx = node.inputs

    # Only rewrite when the indexed variable is produced by a DimShuffle.
    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None

    ds_op = ds.owner.op
    # Restrict to pure transposes; other DimShuffles (adding/dropping
    # broadcastable dims) are not handled by this rewrite.
    if not ds_op.is_transpose:
        return None

    transposition = ds_op.transposition
    [x] = ds.owner.inputs

    # Recover the concrete index entries (slices / scalars) from the
    # node inputs and the op's index template.
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    # Apply the transposition to the indexes
    # Pad with full slices for trailing dims that were not explicitly
    # indexed, then permute the entries into the pre-transpose axis order.
    n_implicit_idxs = x.type.ndim - len(idx_tuple)
    idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs
    new_idxs = [idx_tuple[i] for i in transposition]
    new_x = x[tuple(new_idxs)]

    # Reintroduce any dims dropped by indexing so the original transpose still works
    # (integer indices remove axes, but `ds_op` expects the full rank).
    dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs)
    if dims_dropped_by_new_idx:
        new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx)

    # Apply the transpose
    new_out = ds_op(new_x)

    # Squeeze dims again now that the transpose is done; the dims to drop
    # are the ones the original indexing removed, in post-transpose order.
    if dims_dropped_by_new_idx:
        dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple)
        new_out = squeeze(new_out, axis=dims_dropped_by_original_idx)

    # Cleanup consecutive expand_dims / transpose / squeeze (if any)
    # NOTE(review): assumes local_dimshuffle_lift always returns exactly one
    # replacement for this expand_dims/transpose/squeeze chain — confirm.
    if dims_dropped_by_new_idx:
        [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner)

    return [new_out]
329+
330+
277331
@register_infer_shape
278332
@register_useless
279333
@register_canonicalize

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
310310

311311
out = original_fn(x)
312312
expected_opt_out = expected_fn(x)
313-
opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
313+
opt_out = rewrite_graph(out)
314314
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
315315
[opt_out, expected_opt_out], print_type=True
316316
)
@@ -320,6 +320,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
320320
)
321321

322322

323+
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        # Integer index drops a dim, so the lifted graph transposes only
        # the two remaining axes.
        (lambda x: x.transpose(2, 1, 0)[0], lambda x: x[:, :, 0].transpose(1, 0)),
        # Slice-only indexing keeps the rank; the full transpose survives.
        (lambda x: x.transpose(2, 1, 0)[:, :, 1:], lambda x: x[1:].transpose(2, 1, 0)),
        # Mixed integer and slice indices.
        (
            lambda x: x.transpose(2, 1, 0)[0, :1, 1:],
            lambda x: x[1:, :1, 0].transpose(1, 0),
        ),
        # Two integer indices leave a single axis: no transpose remains.
        (lambda x: x.transpose(2, 1, 0)[0, :1, 1], lambda x: x[1, :1, 0]),
    ],
)
def test_local_subtensor_of_transpose(original_fn, expected_fn):
    """Subtensor is lifted through transpose-only DimShuffles.

    Checks both the rewritten graph structure (against a hand-lifted
    equivalent) and numerical agreement with the unrewritten graph.
    """
    rng = np.random.default_rng(232)
    x = tensor("x", shape=(7, 5, 3))
    x_test = rng.normal(size=x.type.shape)

    out = original_fn(x)
    expected_opt_out = expected_fn(x)
    opt_out = rewrite_graph(out)
    # Structural check: the rewrite must produce exactly the lifted form.
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    # Numerical check: the rewrite must not change computed values.
    np.testing.assert_allclose(
        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )
350+
351+
323352
def test_local_subtensor_of_alloc():
324353
# DebugMode should detect if something goes wrong.
325354
# test shape combination of odd and event shape.

0 commit comments

Comments
 (0)