
Commit cd1e5dc

Implement stack for XTensorVariables

Parent: 155db9f

File tree

5 files changed: 173 additions, 0 deletions

pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.shape

pytensor/xtensor/rewriting/shape.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Stack


@register_lower_xtensor
@node_rewriter(tracks=[Stack])
def lower_stack(fgraph, node):
    [x] = node.inputs
    batch_ndim = x.type.ndim - len(node.op.stacked_dims)
    stacked_axes = [
        i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims
    ]
    end = tuple(range(-len(stacked_axes), 0))

    x_tensor = tensor_from_xtensor(x)
    x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end)
    if batch_ndim == (x.type.ndim - 1):
        # This happens when we stack a single dimension; in that case, all we need is the transpose.
        # Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename
        final_tensor = x_tensor_transposed
    else:
        final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1)
        final_tensor = x_tensor_transposed.reshape(final_shape)

    new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
    return [new_out]
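
For intuition, the lowering above amounts to a moveaxis followed by a reshape on the underlying tensor. A minimal NumPy sketch of the same transformation (dim names and shapes are illustrative, not part of the commit):

import numpy as np

# x carries dims ("a", "b", "c"); stack ("a", "c") into one new trailing dim
x = np.arange(2 * 3 * 5).reshape(2, 3, 5)

# Move the stacked axes (0 and 2) to the end; the batch axis stays in front
x_t = np.moveaxis(x, source=[0, 2], destination=[-2, -1])  # shape (3, 2, 5)

# Collapse the trailing stacked axes into a single one with -1
stacked = x_t.reshape(*x_t.shape[:1], -1)
assert stacked.shape == (3, 10)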

pytensor/xtensor/shape.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
from collections.abc import Sequence

from pytensor.graph import Apply
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor


class Stack(XOp):
    __props__ = ("new_dim_name", "stacked_dims")

    def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]):
        super().__init__()
        if new_dim_name in stacked_dims:
            raise ValueError(
                f"Stacking dim {new_dim_name} must not be in {stacked_dims}"
            )
        if not stacked_dims:
            raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}")
        self.new_dim_name = new_dim_name
        self.stacked_dims = stacked_dims

    def make_node(self, x):
        x = as_xtensor(x)
        if not (set(self.stacked_dims) <= set(x.type.dims)):
            raise ValueError(
                f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}"
            )
        if self.new_dim_name in x.type.dims:
            raise ValueError(
                f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}"
            )
        if len(self.stacked_dims) == x.type.ndim:
            batch_dims, batch_shape = (), ()
        else:
            batch_dims, batch_shape = zip(
                *(
                    (dim, shape)
                    for dim, shape in zip(x.type.dims, x.type.shape)
                    if dim not in self.stacked_dims
                )
            )
        stack_shape = 1
        for dim, shape in zip(x.type.dims, x.type.shape):
            if dim in self.stacked_dims:
                if shape is None:
                    stack_shape = None
                    break
                else:
                    stack_shape *= shape
        output = xtensor(
            dtype=x.type.dtype,
            shape=(*batch_shape, stack_shape),
            dims=(*batch_dims, self.new_dim_name),
        )
        return Apply(self, [x], [output])


def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
    if dim is not None:
        if dims:
            raise ValueError("Cannot use both positional dim and keyword dims in stack")
        dims = dim

    y = x
    for new_dim_name, stacked_dims in dims.items():
        if isinstance(stacked_dims, str):
            raise TypeError(
                f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}"
            )
        y = Stack(new_dim_name, tuple(stacked_dims))(y)
    return y
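
A minimal usage sketch of the stack helper above (dim names and shapes chosen for illustration):

from pytensor.xtensor.shape import stack
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5))

# Keyword form: stack "a" and "c" into a new trailing dim "new_dim"
y = stack(x, new_dim=("a", "c"))
assert y.type.dims == ("b", "new_dim")  # the stacked dim has length 2 * 5 = 10

# Dict form: entries are applied left to right, one Stack Op per entry
z = stack(x, {"ab": ("a", "b"), "c2": ("c",)})
assert z.type.dims == ("ab", "c2")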

pytensor/xtensor/type.py

Lines changed: 5 additions & 0 deletions
@@ -311,6 +311,11 @@ def sel(self, *args, **kwargs):
     def __getitem__(self, idx):
         raise NotImplementedError("Indexing not yet implemented")
 
+    # Reshaping and reorganizing
+    # https://docs.xarray.dev/en/latest/api.html#id8
+    def stack(self, dim=None, **dims):
+        return px.shape.stack(self, dim, **dims)
+
 
 class XTensorConstantSignature(TensorConstantSignature):
     pass
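
With the method above, the same operation can be written xarray-style on the variable itself. A small sketch (dim names illustrative):

from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5))

# Method form mirrors xarray's DataArray.stack
y = x.stack({"ac": ("a", "c")})
assert y.type.dims == ("b", "ac")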

tests/xtensor/test_shape.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# ruff: noqa: E402
import pytest


pytest.importorskip("xarray")

from itertools import chain, combinations

from pytensor.xtensor.shape import stack
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
    xr_arange_like,
    xr_assert_allclose,
    xr_function,
)


def powerset(iterable, min_group_size=0):
    "Subsequences of the iterable from shortest to longest."
    # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    s = list(iterable)
    return chain.from_iterable(
        combinations(s, r) for r in range(min_group_size, len(s) + 1)
    )


def test_stack():
    dims = ("a", "b", "c", "d")
    x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))
    outs = [
        stack(x, new_dim=dims_to_stack)
        for dims_to_stack in powerset(dims, min_group_size=2)
    ]

    fn = xr_function([x], outs)
    x_test = xr_arange_like(x)
    res = fn(x_test)

    expected_res = [
        x_test.stack(new_dim=dims_to_stack)
        for dims_to_stack in powerset(dims, min_group_size=2)
    ]
    for res_i, expected_res_i in zip(res, expected_res):
        xr_assert_allclose(res_i, expected_res_i)


def test_stack_single_dim():
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5))
    out = stack(x, {"d": ["a"]})
    assert out.type.dims == ("b", "c", "d")

    fn = xr_function([x], out)
    x_test = xr_arange_like(x)
    res = fn(x_test)
    expected_res = x_test.stack(d=["a"])
    xr_assert_allclose(res, expected_res)


def test_multiple_stacks():
    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7))
    out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d"))

    fn = xr_function([x], [out])
    x_test = xr_arange_like(x)
    res = fn(x_test)
    expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
    xr_assert_allclose(res[0], expected_res)
