Skip to content

Commit 5988cec

Browse files
committed
Add transpose operation for labeled tensors with ellipsis support
1 parent e32d865 commit 5988cec

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5-
from pytensor.xtensor.shape import Concat, Stack
5+
from pytensor.xtensor.shape import Concat, Stack, Transpose
66

77

88
@register_xcanonicalize
@@ -70,3 +70,35 @@ def lower_concat(fgraph, node):
7070
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
7171
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
7272
return [new_out]
73+
74+
75+
@register_xcanonicalize
76+
@node_rewriter(tracks=[Transpose])
77+
def lower_transpose(fgraph, node):
78+
[x] = node.inputs
79+
# Determine the permutation of axes
80+
out_dims = node.op.dims
81+
in_dims = x.type.dims
82+
# Expand ellipsis if present
83+
if out_dims == () or out_dims == (...,):
84+
expanded_dims = tuple(reversed(in_dims))
85+
elif ... in out_dims:
86+
pre = []
87+
post = []
88+
found = False
89+
for d in out_dims:
90+
if d is ...:
91+
found = True
92+
elif not found:
93+
pre.append(d)
94+
else:
95+
post.append(d)
96+
middle = [d for d in in_dims if d not in pre + post]
97+
expanded_dims = tuple(pre + middle + post)
98+
else:
99+
expanded_dims = out_dims
100+
perm = tuple(in_dims.index(d) for d in expanded_dims)
101+
x_tensor = tensor_from_xtensor(x)
102+
x_tensor_transposed = x_tensor.transpose(perm)
103+
new_out = xtensor_from_tensor(x_tensor_transposed, dims=expanded_dims)
104+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,49 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7373
return y
7474

7575

76+
class Transpose(XOp):
77+
__props__ = ("dims",)
78+
79+
def __init__(self, dims: tuple[str, ...]):
80+
super().__init__()
81+
self.dims = dims
82+
83+
def make_node(self, x):
84+
x = as_xtensor(x)
85+
# Allow ellipsis for full transpose
86+
if self.dims == () or self.dims == (...,):
87+
dims = tuple(reversed(x.type.dims))
88+
else:
89+
# Expand ellipsis if present
90+
if ... in self.dims:
91+
pre = []
92+
post = []
93+
found = False
94+
for d in self.dims:
95+
if d is ...:
96+
found = True
97+
elif not found:
98+
pre.append(d)
99+
else:
100+
post.append(d)
101+
middle = [d for d in x.type.dims if d not in pre + post]
102+
dims = tuple(pre + middle + post)
103+
else:
104+
dims = self.dims
105+
if set(dims) != set(x.type.dims):
106+
raise ValueError(f"Transpose dims {dims} must match {x.type.dims}")
107+
output = xtensor(
108+
dtype=x.type.dtype,
109+
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims),
110+
dims=dims,
111+
)
112+
return Apply(self, [x], [output])
113+
114+
115+
def transpose(x, *dims):
116+
return Transpose(dims)(x)
117+
118+
76119
class Concat(XOp):
77120
__props__ = ("dim",)
78121

tests/xtensor/test_shape.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from xarray import DataArray
1111
from xarray import concat as xr_concat
1212

13-
from pytensor.xtensor.shape import concat, stack
13+
from pytensor.xtensor.shape import concat, stack, transpose
1414
from pytensor.xtensor.type import xtensor
1515
from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like
1616

@@ -24,9 +24,8 @@ def powerset(iterable, min_group_size=0):
2424
)
2525

2626

27-
@pytest.mark.xfail(reason="Not yet implemented")
27+
# @pytest.mark.xfail(reason="Not yet implemented")
2828
def test_transpose():
29-
transpose = None
3029
a, b, c, d, e = "abcde"
3130

3231
x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))

0 commit comments

Comments
 (0)