diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 4868e6e4f7..bb528fcf26 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -189,7 +189,7 @@ def make_node(self, x): def transpose( x, - *dims: str | EllipsisType, + *dim: str | EllipsisType, missing_dims: Literal["raise", "warn", "ignore"] = "raise", ): """Transpose dimensions of the tensor. @@ -198,7 +198,7 @@ def transpose( ---------- x : XTensorVariable Input tensor to transpose. - *dims : str + *dim : str Dimensions to transpose to. Can include ellipsis (...) to represent remaining dimensions in their original order. missing_dims : {"raise", "warn", "ignore"}, optional @@ -220,7 +220,7 @@ def transpose( # Validate dimensions x = as_xtensor(x) x_dims = x.type.dims - invalid_dims = set(dims) - {..., *x_dims} + invalid_dims = set(dim) - {..., *x_dims} if invalid_dims: if missing_dims != "ignore": msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}" @@ -229,21 +229,27 @@ def transpose( else: warnings.warn(msg) # Handle missing dimensions if not raising - dims = tuple(d for d in dims if d in x_dims or d is ...) - - if dims == () or dims == (...,): - dims = tuple(reversed(x_dims)) - elif ... in dims: - if dims.count(...) > 1: + dim = tuple(d for d in dim if d in x_dims or d is ...) + + if dim == (): + dim = tuple(reversed(x_dims)) + elif dim == (...,): + dim = x_dims + elif ... in dim: + if dim.count(...) > 1: raise ValueError("Ellipsis (...) can only appear once in the dimensions") # Handle ellipsis expansion - ellipsis_idx = dims.index(...) - pre = dims[:ellipsis_idx] - post = dims[ellipsis_idx + 1 :] + ellipsis_idx = dim.index(...) + pre = dim[:ellipsis_idx] + post = dim[ellipsis_idx + 1 :] middle = [d for d in x_dims if d not in pre + post] - dims = (*pre, *middle, *post) + dim = (*pre, *middle, *post) + + if dim == x_dims: + # No-op transpose + return x - return Transpose(typing.cast(tuple[str], dims))(x) + return Transpose(dims=typing.cast(tuple[str], dim))(x) class Concat(XOp): diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 93f2bb5499..94b0eeedfe 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -691,14 +691,14 @@ def diff(self, dim, n=1): # https://docs.xarray.dev/en/latest/api.html#id8 def transpose( self, - *dims: str | EllipsisType, + *dim: str | EllipsisType, missing_dims: Literal["raise", "warn", "ignore"] = "raise", ): """Transpose dimensions of the tensor. Parameters ---------- - *dims : str | Ellipsis + *dim : str | Ellipsis Dimensions to transpose. If empty, performs a full transpose. Can use ellipsis (...) to represent remaining dimensions. missing_dims : {"raise", "warn", "ignore"}, default="raise" @@ -718,7 +718,7 @@ def transpose( If missing_dims="raise" and any dimensions don't exist. If multiple ellipsis are provided. """ - return px.shape.transpose(self, *dims, missing_dims=missing_dims) + return px.shape.transpose(self, *dim, missing_dims=missing_dims) def stack(self, dim, **dims): return px.shape.stack(self, dim, **dims) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index da2c5f1913..6abd7b5103 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -15,7 +15,6 @@ from pytensor.xtensor.shape import ( concat, stack, - transpose, unstack, ) from pytensor.xtensor.type import xtensor @@ -46,13 +45,14 @@ def test_transpose(): permutations = [ (a, b, c, d, e), # identity (e, d, c, b, a), # full tranpose - (), # eqivalent to full transpose + (), # equivalent to full transpose (a, b, c, e, d), # swap last two dims (..., d, c), # equivalent to (a, b, e, d, c) (b, a, ..., e, d), # equivalent to (b, a, c, d, e) (c, a, ...), # equivalent to (c, a, b, d, e) + (...,), # no op ] - outs = [transpose(x, *perm) for perm in permutations] + outs = [x.transpose(*perm) for perm in permutations] fn = xr_function([x], outs) x_test = xr_arange_like(x)