Skip to content

Commit 382fa8d

Browse files
committed
Tweak transpose
1 parent de759b1 commit 382fa8d

File tree

3 files changed

+111
-144
lines changed

3 files changed

+111
-144
lines changed

pytensor/xtensor/shape.py

Lines changed: 47 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.graph import Apply
77
from pytensor.scalar import upcast
88
from pytensor.xtensor.basic import XOp
9-
from pytensor.xtensor.type import as_xtensor, xtensor
9+
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
1010

1111

1212
class Stack(XOp):
@@ -75,102 +75,55 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7575
return y
7676

7777

78-
def expand_ellipsis(
79-
dims: tuple[str, ...],
80-
all_dims: tuple[str, ...],
81-
validate: bool = True,
82-
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
83-
) -> tuple[str, ...]:
84-
"""Expand ellipsis in dimension permutation.
85-
86-
Parameters
87-
----------
88-
dims : tuple[str, ...]
89-
The dimension permutation, which may contain ellipsis
90-
all_dims : tuple[str, ...]
91-
All available dimensions
92-
validate : bool, default True
93-
Whether to check that all non-ellipsis elements in dims are valid dimension names.
94-
missing_dims : {"raise", "warn", "ignore"}, optional
95-
How to handle dimensions that don't exist in all_dims:
96-
- "raise": Raise an error if any dimensions don't exist (default)
97-
- "warn": Warn if any dimensions don't exist
98-
- "ignore": Silently ignore any dimensions that don't exist
99-
100-
Returns
101-
-------
102-
tuple[str, ...]
103-
The expanded dimension permutation
104-
105-
Raises
106-
------
107-
ValueError
108-
If more than one ellipsis is present in dims.
109-
If any non-ellipsis element in dims is not a valid dimension name and validate is True.
110-
If missing_dims is "raise" and any dimension in dims doesn't exist in all_dims.
111-
"""
112-
# Handle empty or full ellipsis case
113-
if dims == () or dims == (...,):
114-
return tuple(reversed(all_dims))
115-
116-
# Check for multiple ellipses
117-
if dims.count(...) > 1:
118-
raise ValueError("an index can only have a single ellipsis ('...')")
119-
120-
# Validate dimensions if requested
121-
if validate:
122-
invalid_dims = set(dims) - {..., *all_dims}
123-
if invalid_dims:
124-
if missing_dims == "raise":
125-
raise ValueError(
126-
f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}"
127-
)
128-
elif missing_dims == "warn":
129-
warnings.warn(f"Dimensions {invalid_dims} do not exist in {all_dims}")
130-
131-
# Handle missing dimensions if not raising
132-
if missing_dims in ("ignore", "warn"):
133-
dims = tuple(d for d in dims if d in all_dims or d is ...)
134-
135-
# If no ellipsis, just return the dimensions
136-
if ... not in dims:
137-
return dims
138-
139-
# Handle ellipsis expansion
140-
ellipsis_idx = dims.index(...)
141-
pre = list(dims[:ellipsis_idx])
142-
post = list(dims[ellipsis_idx + 1 :])
143-
middle = [d for d in all_dims if d not in pre + post]
144-
return tuple(pre + middle + post)
145-
146-
14778
class Transpose(XOp):
148-
__props__ = ("dims", "missing_dims")
79+
__props__ = ("dims",)
14980

15081
def __init__(
15182
self,
15283
dims: tuple[str | Literal[...], ...],
153-
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
15484
):
15585
super().__init__()
86+
if dims.count(...) > 1:
87+
raise ValueError("an index can only have a single ellipsis ('...')")
15688
self.dims = dims
157-
self.missing_dims = missing_dims
15889

15990
def make_node(self, x):
16091
x = as_xtensor(x)
161-
dims = expand_ellipsis(
162-
self.dims, x.type.dims, validate=True, missing_dims=self.missing_dims
163-
)
92+
93+
transpose_dims = self.dims
94+
x_dims = x.type.dims
95+
96+
if transpose_dims == () or transpose_dims == (...,):
97+
out_dims = tuple(reversed(x_dims))
98+
elif ... in transpose_dims:
99+
# Handle ellipsis expansion
100+
ellipsis_idx = transpose_dims.index(...)
101+
pre = transpose_dims[:ellipsis_idx]
102+
post = transpose_dims[ellipsis_idx + 1 :]
103+
middle = [d for d in x_dims if d not in pre + post]
104+
out_dims = (*pre, *middle, *post)
105+
if set(out_dims) != set(x_dims):
106+
raise ValueError(f"{out_dims} must be a permuted list of {x_dims}")
107+
else:
108+
out_dims = transpose_dims
109+
if set(out_dims) != set(x_dims):
110+
raise ValueError(
111+
f"{out_dims} must be a permuted list of {x_dims}, unless `...` is included"
112+
)
164113

165114
output = xtensor(
166115
dtype=x.type.dtype,
167-
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims),
168-
dims=dims,
116+
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in out_dims),
117+
dims=out_dims,
169118
)
170119
return Apply(self, [x], [output])
171120

172121

173-
def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"):
122+
def transpose(
123+
x,
124+
*dims: str | Literal[...],
125+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
126+
) -> XTensorVariable:
174127
"""Transpose dimensions of the tensor.
175128
176129
Parameters
@@ -196,7 +149,21 @@ def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "rais
196149
ValueError
197150
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
198151
"""
199-
return Transpose(dims, missing_dims=missing_dims)(x)
152+
# Validate dimensions
153+
x = as_xtensor(x)
154+
all_dims = x.type.dims
155+
invalid_dims = set(dims) - {..., *all_dims}
156+
if invalid_dims:
157+
if missing_dims != "ignore":
158+
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}"
159+
if missing_dims == "raise":
160+
raise ValueError(msg)
161+
else:
162+
warnings.warn(msg)
163+
# Handle missing dimensions if not raising
164+
dims = tuple(d for d in dims if d in all_dims or d is ...)
165+
166+
return Transpose(dims)(x)
200167

201168

202169
class Concat(XOp):

pytensor/xtensor/type.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
XARRAY_AVAILABLE = False
1111

1212
from collections.abc import Sequence
13-
from typing import TypeVar, Union
13+
from typing import Literal, TypeVar
1414

1515
import numpy as np
1616

@@ -357,7 +357,11 @@ def imag(self):
357357
def real(self):
358358
return px.math.real(self)
359359

360-
def transpose(self, *dims: Union[str, type(Ellipsis)], missing_dims: Literal["raise", "warn", "ignore"] = "raise"):
360+
def transpose(
361+
self,
362+
*dims: str | Literal[...],
363+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
364+
) -> "XTensorVariable":
361365
"""Transpose dimensions of the tensor.
362366
363367
Parameters

tests/xtensor/test_shape.py

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# ruff: noqa: E402
2+
import re
3+
24
import pytest
35

46

@@ -51,6 +53,62 @@ def test_transpose():
5153
xr_assert_allclose(res_i, expected_res_i)
5254

5355

56+
def test_xtensor_variable_transpose():
57+
"""Test the transpose() method of XTensorVariable."""
58+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
59+
60+
# Test basic transpose
61+
out = x.transpose()
62+
fn = xr_function([x], out)
63+
x_test = xr_arange_like(x)
64+
xr_assert_allclose(fn(x_test), x_test.transpose())
65+
66+
# Test transpose with specific dimensions
67+
out = x.transpose("c", "a", "b")
68+
fn = xr_function([x], out)
69+
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
70+
71+
# Test transpose with ellipsis
72+
out = x.transpose("c", ...)
73+
fn = xr_function([x], out)
74+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
75+
76+
# Test error cases
77+
with pytest.raises(
78+
ValueError,
79+
match=re.escape(
80+
"Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
81+
),
82+
):
83+
x.transpose("d")
84+
85+
with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
86+
x.transpose("a", ..., "b", ...)
87+
88+
# Test missing_dims parameter
89+
# Test ignore
90+
out = x.transpose("c", ..., "d", missing_dims="ignore")
91+
fn = xr_function([x], out)
92+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
93+
94+
# Test warn
95+
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
96+
out = x.transpose("c", ..., "d", missing_dims="warn")
97+
fn = xr_function([x], out)
98+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
99+
100+
101+
def test_xtensor_variable_T():
102+
"""Test the T property of XTensorVariable."""
103+
# Test T property with 3D tensor
104+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
105+
out = x.T
106+
107+
fn = xr_function([x], out)
108+
x_test = xr_arange_like(x)
109+
xr_assert_allclose(fn(x_test), x_test.T)
110+
111+
54112
def test_stack():
55113
dims = ("a", "b", "c", "d")
56114
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))
@@ -153,65 +211,3 @@ def test_concat_scalar():
153211
res = fn(x1_test, x2_test)
154212
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
155213
xr_assert_allclose(res, expected_res)
156-
157-
158-
def test_xtensor_variable_transpose():
159-
"""Test the transpose() method of XTensorVariable."""
160-
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
161-
162-
# Test basic transpose
163-
out = x.transpose()
164-
fn = xr_function([x], out)
165-
x_test = xr_arange_like(x)
166-
xr_assert_allclose(fn(x_test), x_test.transpose())
167-
168-
# Test transpose with specific dimensions
169-
out = x.transpose("c", "a", "b")
170-
fn = xr_function([x], out)
171-
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
172-
173-
# Test transpose with ellipsis
174-
out = x.transpose("c", ...)
175-
fn = xr_function([x], out)
176-
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
177-
178-
# Test error cases
179-
with pytest.raises(
180-
ValueError,
181-
match="Invalid dimensions: {'d'}. Available dimensions: \\('a', 'b', 'c'\\)",
182-
):
183-
x.transpose("d")
184-
185-
with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
186-
x.transpose("a", ..., "b", ...)
187-
188-
# Test missing_dims parameter
189-
# Test ignore
190-
out = x.transpose("c", ..., "d", missing_dims="ignore")
191-
fn = xr_function([x], out)
192-
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
193-
194-
# Test warn
195-
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
196-
out = x.transpose("c", ..., "d", missing_dims="warn")
197-
fn = xr_function([x], out)
198-
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
199-
200-
201-
def test_xtensor_variable_T():
202-
"""Test the T property of XTensorVariable."""
203-
# Test T property with 3D tensor
204-
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
205-
out = x.T
206-
207-
fn = xr_function([x], out)
208-
x_test = xr_arange_like(x)
209-
xr_assert_allclose(fn(x_test), x_test.transpose())
210-
211-
# Test T property with 2D tensor
212-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
213-
out = x.T
214-
215-
fn = xr_function([x], out)
216-
x_test = xr_arange_like(x)
217-
xr_assert_allclose(fn(x_test), x_test.transpose())

0 commit comments

Comments
 (0)