|
1 | | -from collections.abc import Iterable |
| 1 | +from collections.abc import Iterable, Sequence |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 |
|
|
17 | 17 | ) |
18 | 18 | from pytensor.tensor.elemwise import DimShuffle, Elemwise |
19 | 19 | from pytensor.tensor.exceptions import NotScalarConstantError |
| 20 | +from pytensor.tensor.extra_ops import squeeze |
20 | 21 | from pytensor.tensor.math import Dot, ceil_intdiv, dot |
21 | 22 | from pytensor.tensor.rewriting.basic import ( |
22 | 23 | register_canonicalize, |
23 | 24 | register_specialize, |
24 | 25 | register_stabilize, |
25 | 26 | ) |
| 27 | +from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift |
26 | 28 | from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless |
27 | 29 | from pytensor.tensor.shape import ( |
28 | 30 | Shape, |
|
44 | 46 | from pytensor.tensor.type_other import SliceType |
45 | 47 |
|
46 | 48 |
|
| 49 | +def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]: |
| 50 | + # Inputs can be slice or integer indexes |
| 51 | + # Slices keep the dimensions, integers collapse them |
| 52 | + return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) |
| 53 | + |
| 54 | + |
47 | 55 | @register_canonicalize |
48 | 56 | @register_stabilize |
49 | 57 | @register_specialize |
@@ -280,6 +288,54 @@ def local_subtensor_of_expand_dims(fgraph, node): |
280 | 288 | return [out] |
281 | 289 |
|
282 | 290 |
|
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_transpose(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only transposes.

    transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2))
    transpose(x, (1, 2, 0))[i:, j:, k:] -> transpose(x[k:, i:, j:], (1, 2, 0))

    Returns ``None`` when the indexed variable is not the output of a
    transpose-only DimShuffle, otherwise a single-element list holding the
    replacement variable.
    """
    ds, *idx = node.inputs

    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None

    ds_op = ds.owner.op
    if not ds_op.is_transpose:
        return None

    transposition = ds_op.transposition
    [x] = ds.owner.inputs

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    # Pad with full slices so that there is exactly one index per dimension
    n_implicit_idxs = x.type.ndim - len(idx_tuple)
    idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs

    # Map the indexes back onto the axes of `x`.
    # Output axis `out_axis` of the transpose is input axis
    # `transposition[out_axis]`, so the index applied at `out_axis` must be
    # placed at position `transposition[out_axis]` (the inverse permutation).
    # Gathering `idx_tuple[i] for i in transposition` instead would only be
    # correct for permutations that are their own inverse.
    new_idxs = [slice(None)] * len(idx_tuple)
    for out_axis, in_axis in enumerate(transposition):
        new_idxs[in_axis] = idx_tuple[out_axis]
    new_x = x[tuple(new_idxs)]

    # Reintroduce any dims dropped by indexing so the original transpose still works
    dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs)
    if dims_dropped_by_new_idx:
        new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx)

    # Apply the transpose
    new_out = ds_op(new_x)

    if dims_dropped_by_new_idx:
        # Squeeze dims again now that the transpose is done. After the
        # transpose the axes are in the order of the original indexes.
        dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple)
        new_out = squeeze(new_out, axis=dims_dropped_by_original_idx)

        # Cleanup consecutive expand_dims / transpose / squeeze
        [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner)

    return [new_out]
| 337 | + |
| 338 | + |
283 | 339 | @register_infer_shape |
284 | 340 | @register_useless |
285 | 341 | @register_canonicalize |
|
0 commit comments