Skip to content

Commit 5936ab2

Browse files
committed
Refactor: Extract ellipsis expansion logic into helper function
1 parent 5988cec commit 5936ab2

File tree

2 files changed

+38
-39
lines changed

2 files changed

+38
-39
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5-
from pytensor.xtensor.shape import Concat, Stack, Transpose
5+
from pytensor.xtensor.shape import Concat, Stack, Transpose, expand_ellipsis
66

77

88
@register_xcanonicalize
@@ -79,24 +79,7 @@ def lower_transpose(fgraph, node):
7979
# Determine the permutation of axes
8080
out_dims = node.op.dims
8181
in_dims = x.type.dims
82-
# Expand ellipsis if present
83-
if out_dims == () or out_dims == (...,):
84-
expanded_dims = tuple(reversed(in_dims))
85-
elif ... in out_dims:
86-
pre = []
87-
post = []
88-
found = False
89-
for d in out_dims:
90-
if d is ...:
91-
found = True
92-
elif not found:
93-
pre.append(d)
94-
else:
95-
post.append(d)
96-
middle = [d for d in in_dims if d not in pre + post]
97-
expanded_dims = tuple(pre + middle + post)
98-
else:
99-
expanded_dims = out_dims
82+
expanded_dims = expand_ellipsis(out_dims, in_dims)
10083
perm = tuple(in_dims.index(d) for d in expanded_dims)
10184
x_tensor = tensor_from_xtensor(x)
10285
x_tensor_transposed = x_tensor.transpose(perm)

pytensor/xtensor/shape.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,41 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7373
return y
7474

7575

76+
def expand_ellipsis(dims: tuple[str, ...], all_dims: tuple[str, ...]) -> tuple[str, ...]:
77+
"""Expand ellipsis in dimension permutation.
78+
79+
Parameters
80+
----------
81+
dims : tuple[str, ...]
82+
The dimension permutation, which may contain ellipsis
83+
all_dims : tuple[str, ...]
84+
All available dimensions
85+
86+
Returns
87+
-------
88+
tuple[str, ...]
89+
The expanded dimension permutation
90+
"""
91+
if dims == () or dims == (...,):
92+
return tuple(reversed(all_dims))
93+
94+
if ... not in dims:
95+
return dims
96+
97+
pre = []
98+
post = []
99+
found = False
100+
for d in dims:
101+
if d is ...:
102+
found = True
103+
elif not found:
104+
pre.append(d)
105+
else:
106+
post.append(d)
107+
middle = [d for d in all_dims if d not in pre + post]
108+
return tuple(pre + middle + post)
109+
110+
76111
class Transpose(XOp):
77112
__props__ = ("dims",)
78113

@@ -82,26 +117,7 @@ def __init__(self, dims: tuple[str, ...]):
82117

83118
def make_node(self, x):
84119
x = as_xtensor(x)
85-
# Allow ellipsis for full transpose
86-
if self.dims == () or self.dims == (...,):
87-
dims = tuple(reversed(x.type.dims))
88-
else:
89-
# Expand ellipsis if present
90-
if ... in self.dims:
91-
pre = []
92-
post = []
93-
found = False
94-
for d in self.dims:
95-
if d is ...:
96-
found = True
97-
elif not found:
98-
pre.append(d)
99-
else:
100-
post.append(d)
101-
middle = [d for d in x.type.dims if d not in pre + post]
102-
dims = tuple(pre + middle + post)
103-
else:
104-
dims = self.dims
120+
dims = expand_ellipsis(self.dims, x.type.dims)
105121
if set(dims) != set(x.type.dims):
106122
raise ValueError(f"Transpose dims {dims} must match {x.type.dims}")
107123
output = xtensor(

0 commit comments

Comments
 (0)