@@ -73,6 +73,41 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7373 return y
7474
7575
76+ def expand_ellipsis (dims : tuple [str , ...], all_dims : tuple [str , ...]) -> tuple [str , ...]:
77+ """Expand ellipsis in dimension permutation.
78+
79+ Parameters
80+ ----------
81+ dims : tuple[str, ...]
82+ The dimension permutation, which may contain ellipsis
83+ all_dims : tuple[str, ...]
84+ All available dimensions
85+
86+ Returns
87+ -------
88+ tuple[str, ...]
89+ The expanded dimension permutation
90+ """
91+ if dims == () or dims == (...,):
92+ return tuple (reversed (all_dims ))
93+
94+ if ... not in dims :
95+ return dims
96+
97+ pre = []
98+ post = []
99+ found = False
100+ for d in dims :
101+ if d is ...:
102+ found = True
103+ elif not found :
104+ pre .append (d )
105+ else :
106+ post .append (d )
107+ middle = [d for d in all_dims if d not in pre + post ]
108+ return tuple (pre + middle + post )
109+
110+
76111class Transpose (XOp ):
77112 __props__ = ("dims" ,)
78113
@@ -82,26 +117,7 @@ def __init__(self, dims: tuple[str, ...]):
82117
83118 def make_node (self , x ):
84119 x = as_xtensor (x )
85- # Allow ellipsis for full transpose
86- if self .dims == () or self .dims == (...,):
87- dims = tuple (reversed (x .type .dims ))
88- else :
89- # Expand ellipsis if present
90- if ... in self .dims :
91- pre = []
92- post = []
93- found = False
94- for d in self .dims :
95- if d is ...:
96- found = True
97- elif not found :
98- pre .append (d )
99- else :
100- post .append (d )
101- middle = [d for d in x .type .dims if d not in pre + post ]
102- dims = tuple (pre + middle + post )
103- else :
104- dims = self .dims
120+ dims = expand_ellipsis (self .dims , x .type .dims )
105121 if set (dims ) != set (x .type .dims ):
106122 raise ValueError (f"Transpose dims { dims } must match { x .type .dims } " )
107123 output = xtensor (
0 commit comments