Skip to content

Commit de759b1

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement transpose for XTensorVariables
1 parent 56bd328 commit de759b1

File tree

4 files changed

+247
-5
lines changed

4 files changed

+247
-5
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5-
from pytensor.xtensor.shape import Concat, Stack
5+
from pytensor.xtensor.shape import Concat, Stack, Transpose
66

77

88
@register_xcanonicalize
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
7070
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
7171
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
7272
return [new_out]
73+
74+
75+
@register_xcanonicalize
76+
@node_rewriter(tracks=[Transpose])
77+
def lower_transpose(fgraph, node):
78+
[x] = node.inputs
79+
# Use the final dimensions that were already computed in make_node
80+
out_dims = node.outputs[0].type.dims
81+
in_dims = x.type.dims
82+
83+
# Compute the permutation based on the final dimensions
84+
perm = tuple(in_dims.index(d) for d in out_dims)
85+
x_tensor = tensor_from_xtensor(x)
86+
x_tensor_transposed = x_tensor.transpose(perm)
87+
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
88+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import warnings
12
from collections.abc import Sequence
3+
from typing import Literal
24

35
from pytensor import Variable
46
from pytensor.graph import Apply
@@ -73,6 +75,130 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7375
return y
7476

7577

78+
def expand_ellipsis(
79+
dims: tuple[str, ...],
80+
all_dims: tuple[str, ...],
81+
validate: bool = True,
82+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
83+
) -> tuple[str, ...]:
84+
"""Expand ellipsis in dimension permutation.
85+
86+
Parameters
87+
----------
88+
dims : tuple[str, ...]
89+
The dimension permutation, which may contain ellipsis
90+
all_dims : tuple[str, ...]
91+
All available dimensions
92+
validate : bool, default True
93+
Whether to check that all non-ellipsis elements in dims are valid dimension names.
94+
missing_dims : {"raise", "warn", "ignore"}, optional
95+
How to handle dimensions that don't exist in all_dims:
96+
- "raise": Raise an error if any dimensions don't exist (default)
97+
- "warn": Warn if any dimensions don't exist
98+
- "ignore": Silently ignore any dimensions that don't exist
99+
100+
Returns
101+
-------
102+
tuple[str, ...]
103+
The expanded dimension permutation
104+
105+
Raises
106+
------
107+
ValueError
108+
If more than one ellipsis is present in dims.
109+
If any non-ellipsis element in dims is not a valid dimension name and validate is True.
110+
If missing_dims is "raise" and any dimension in dims doesn't exist in all_dims.
111+
"""
112+
# Handle empty or full ellipsis case
113+
if dims == () or dims == (...,):
114+
return tuple(reversed(all_dims))
115+
116+
# Check for multiple ellipses
117+
if dims.count(...) > 1:
118+
raise ValueError("an index can only have a single ellipsis ('...')")
119+
120+
# Validate dimensions if requested
121+
if validate:
122+
invalid_dims = set(dims) - {..., *all_dims}
123+
if invalid_dims:
124+
if missing_dims == "raise":
125+
raise ValueError(
126+
f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}"
127+
)
128+
elif missing_dims == "warn":
129+
warnings.warn(f"Dimensions {invalid_dims} do not exist in {all_dims}")
130+
131+
# Handle missing dimensions if not raising
132+
if missing_dims in ("ignore", "warn"):
133+
dims = tuple(d for d in dims if d in all_dims or d is ...)
134+
135+
# If no ellipsis, just return the dimensions
136+
if ... not in dims:
137+
return dims
138+
139+
# Handle ellipsis expansion
140+
ellipsis_idx = dims.index(...)
141+
pre = list(dims[:ellipsis_idx])
142+
post = list(dims[ellipsis_idx + 1 :])
143+
middle = [d for d in all_dims if d not in pre + post]
144+
return tuple(pre + middle + post)
145+
146+
147+
class Transpose(XOp):
148+
__props__ = ("dims", "missing_dims")
149+
150+
def __init__(
151+
self,
152+
dims: tuple[str | Literal[...], ...],
153+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
154+
):
155+
super().__init__()
156+
self.dims = dims
157+
self.missing_dims = missing_dims
158+
159+
def make_node(self, x):
160+
x = as_xtensor(x)
161+
dims = expand_ellipsis(
162+
self.dims, x.type.dims, validate=True, missing_dims=self.missing_dims
163+
)
164+
165+
output = xtensor(
166+
dtype=x.type.dtype,
167+
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims),
168+
dims=dims,
169+
)
170+
return Apply(self, [x], [output])
171+
172+
173+
def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"):
174+
"""Transpose dimensions of the tensor.
175+
176+
Parameters
177+
----------
178+
x : XTensorVariable
179+
Input tensor to transpose.
180+
*dims : str
181+
Dimensions to transpose to. Can include ellipsis (...) to represent
182+
remaining dimensions in their original order.
183+
missing_dims : {"raise", "warn", "ignore"}, optional
184+
How to handle dimensions that don't exist in the input tensor:
185+
- "raise": Raise an error if any dimensions don't exist (default)
186+
- "warn": Warn if any dimensions don't exist
187+
- "ignore": Silently ignore any dimensions that don't exist
188+
189+
Returns
190+
-------
191+
XTensorVariable
192+
Transposed tensor with reordered dimensions.
193+
194+
Raises
195+
------
196+
ValueError
197+
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
198+
"""
199+
return Transpose(dims, missing_dims=missing_dims)(x)
200+
201+
76202
class Concat(XOp):
77203
__props__ = ("dim",)
78204

pytensor/xtensor/type.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
XARRAY_AVAILABLE = False
1111

1212
from collections.abc import Sequence
13-
from typing import TypeVar
13+
from typing import TypeVar, Union
1414

1515
import numpy as np
1616

@@ -357,6 +357,46 @@ def imag(self):
357357
def real(self):
358358
return px.math.real(self)
359359

360+
def transpose(self, *dims: Union[str, type(Ellipsis)], missing_dims: Literal["raise", "warn", "ignore"] = "raise"):
361+
"""Transpose dimensions of the tensor.
362+
363+
Parameters
364+
----------
365+
*dims : str | Ellipsis
366+
Dimensions to transpose. If empty, performs a full transpose.
367+
Can use ellipsis (...) to represent remaining dimensions.
368+
missing_dims : {"raise", "warn", "ignore"}, default="raise"
369+
How to handle dimensions that don't exist in the tensor:
370+
- "raise": Raise an error if any dimensions don't exist
371+
- "warn": Warn if any dimensions don't exist
372+
- "ignore": Silently ignore any dimensions that don't exist
373+
374+
Returns
375+
-------
376+
XTensorVariable
377+
Transposed tensor with reordered dimensions.
378+
379+
Raises
380+
------
381+
ValueError
382+
If missing_dims="raise" and any dimensions don't exist.
383+
If multiple ellipsis are provided.
384+
"""
385+
return px.shape.transpose(self, *dims, missing_dims=missing_dims)
386+
387+
@property
388+
def T(self) -> "XTensorVariable":
389+
"""Return the full transpose of the tensor.
390+
391+
This is equivalent to calling transpose() with no arguments.
392+
393+
Returns
394+
-------
395+
XTensorVariable
396+
Fully transposed tensor.
397+
"""
398+
return self.transpose()
399+
360400
# Aggregation
361401
# https://docs.xarray.dev/en/latest/api.html#id6
362402
def all(self, dim):

tests/xtensor/test_shape.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from xarray import concat as xr_concat
1111

12-
from pytensor.xtensor.shape import concat, stack
12+
from pytensor.xtensor.shape import concat, stack, transpose
1313
from pytensor.xtensor.type import xtensor
1414
from tests.xtensor.util import (
1515
xr_arange_like,
@@ -28,9 +28,7 @@ def powerset(iterable, min_group_size=0):
2828
)
2929

3030

31-
@pytest.mark.xfail(reason="Not yet implemented")
3231
def test_transpose():
33-
transpose = None
3432
a, b, c, d, e = "abcde"
3533

3634
x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
@@ -155,3 +153,65 @@ def test_concat_scalar():
155153
res = fn(x1_test, x2_test)
156154
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
157155
xr_assert_allclose(res, expected_res)
156+
157+
158+
def test_xtensor_variable_transpose():
159+
"""Test the transpose() method of XTensorVariable."""
160+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
161+
162+
# Test basic transpose
163+
out = x.transpose()
164+
fn = xr_function([x], out)
165+
x_test = xr_arange_like(x)
166+
xr_assert_allclose(fn(x_test), x_test.transpose())
167+
168+
# Test transpose with specific dimensions
169+
out = x.transpose("c", "a", "b")
170+
fn = xr_function([x], out)
171+
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
172+
173+
# Test transpose with ellipsis
174+
out = x.transpose("c", ...)
175+
fn = xr_function([x], out)
176+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
177+
178+
# Test error cases
179+
with pytest.raises(
180+
ValueError,
181+
match="Invalid dimensions: {'d'}. Available dimensions: \\('a', 'b', 'c'\\)",
182+
):
183+
x.transpose("d")
184+
185+
with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
186+
x.transpose("a", ..., "b", ...)
187+
188+
# Test missing_dims parameter
189+
# Test ignore
190+
out = x.transpose("c", ..., "d", missing_dims="ignore")
191+
fn = xr_function([x], out)
192+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
193+
194+
# Test warn
195+
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
196+
out = x.transpose("c", ..., "d", missing_dims="warn")
197+
fn = xr_function([x], out)
198+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
199+
200+
201+
def test_xtensor_variable_T():
202+
"""Test the T property of XTensorVariable."""
203+
# Test T property with 3D tensor
204+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
205+
out = x.T
206+
207+
fn = xr_function([x], out)
208+
x_test = xr_arange_like(x)
209+
xr_assert_allclose(fn(x_test), x_test.transpose())
210+
211+
# Test T property with 2D tensor
212+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
213+
out = x.T
214+
215+
fn = xr_function([x], out)
216+
x_test = xr_arange_like(x)
217+
xr_assert_allclose(fn(x_test), x_test.transpose())

0 commit comments

Comments
 (0)