@@ -189,7 +189,7 @@ def make_node(self, x):
189
189
190
190
def transpose (
191
191
x ,
192
- * dims : str | EllipsisType ,
192
+ * dim : str | EllipsisType ,
193
193
missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
194
194
):
195
195
"""Transpose dimensions of the tensor.
@@ -198,7 +198,7 @@ def transpose(
198
198
----------
199
199
x : XTensorVariable
200
200
Input tensor to transpose.
201
- *dims : str
201
+ *dim : str
202
202
Dimensions to transpose to. Can include ellipsis (...) to represent
203
203
remaining dimensions in their original order.
204
204
missing_dims : {"raise", "warn", "ignore"}, optional
@@ -220,7 +220,7 @@ def transpose(
220
220
# Validate dimensions
221
221
x = as_xtensor (x )
222
222
x_dims = x .type .dims
223
- invalid_dims = set (dims ) - {..., * x_dims }
223
+ invalid_dims = set (dim ) - {..., * x_dims }
224
224
if invalid_dims :
225
225
if missing_dims != "ignore" :
226
226
msg = f"Dimensions { invalid_dims } do not exist. Expected one or more of: { x_dims } "
@@ -229,21 +229,27 @@ def transpose(
229
229
else :
230
230
warnings .warn (msg )
231
231
# Handle missing dimensions if not raising
232
- dims = tuple (d for d in dims if d in x_dims or d is ...)
233
-
234
- if dims == () or dims == (...,):
235
- dims = tuple (reversed (x_dims ))
236
- elif ... in dims :
237
- if dims .count (...) > 1 :
232
+ dim = tuple (d for d in dim if d in x_dims or d is ...)
233
+
234
+ if dim == ():
235
+ dim = tuple (reversed (x_dims ))
236
+ elif dim == (...,):
237
+ dim = x_dims
238
+ elif ... in dim :
239
+ if dim .count (...) > 1 :
238
240
raise ValueError ("Ellipsis (...) can only appear once in the dimensions" )
239
241
# Handle ellipsis expansion
240
- ellipsis_idx = dims .index (...)
241
- pre = dims [:ellipsis_idx ]
242
- post = dims [ellipsis_idx + 1 :]
242
+ ellipsis_idx = dim .index (...)
243
+ pre = dim [:ellipsis_idx ]
244
+ post = dim [ellipsis_idx + 1 :]
243
245
middle = [d for d in x_dims if d not in pre + post ]
244
- dims = (* pre , * middle , * post )
246
+ dim = (* pre , * middle , * post )
247
+
248
+ if dim == x_dims :
249
+ # No-op transpose
250
+ return x
245
251
246
- return Transpose (typing .cast (tuple [str ], dims ))(x )
252
+ return Transpose (dims = typing .cast (tuple [str ], dim ))(x )
247
253
248
254
249
255
class Concat (XOp ):
0 commit comments