Skip to content

Commit 5f4253f

Browse files
committed
first pass at unstack
1 parent f5d426f commit 5f4253f

File tree

3 files changed

+146
-2
lines changed

3 files changed

+146
-2
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5-
from pytensor.xtensor.shape import Stack
5+
from pytensor.xtensor.shape import Stack, UnStack
66

77

88
@register_xcanonicalize
@@ -27,3 +27,19 @@ def lower_stack(fgraph, node):
2727

2828
new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
2929
return [new_out]
30+
31+
32+
@register_xcanonicalize
33+
@node_rewriter(tracks=[UnStack])
34+
def lower_unstack(fgraph, node):
35+
[x] = node.inputs
36+
axis_to_unstack = x.type.dims.index(node.op.old_dim_name)
37+
38+
x_tensor = tensor_from_xtensor(x)
39+
x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1])
40+
final_tensor = x_tensor_transposed.reshape(
41+
(*x_tensor_transposed.shape[:-1], *node.op.unstacked_lengths)
42+
)
43+
44+
new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
45+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,75 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
6969
)
7070
y = Stack(new_dim_name, tuple(stacked_dims))(y)
7171
return y
72+
73+
74+
class UnStack(XOp):
75+
__props__ = ("old_dim_name", "unstacked_dims", "unstacked_lengths")
76+
77+
def __init__(
78+
self,
79+
old_dim_name: str,
80+
unstacked_dims: tuple[str, ...],
81+
unstacked_lengths: tuple[int, ...],
82+
):
83+
super().__init__()
84+
if old_dim_name in unstacked_dims:
85+
raise ValueError(
86+
f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}"
87+
)
88+
if len(unstacked_dims) != len(unstacked_lengths):
89+
raise ValueError(
90+
"Tuples with unstacked dim names and lengths must have the same length "
91+
f"but have {len(unstacked_dims)} and {len(unstacked_lengths)}"
92+
)
93+
if not unstacked_dims:
94+
raise ValueError("Dims to unstack into can't be empty.")
95+
if len(unstacked_dims) == 1:
96+
raise ValueError("Only one dimension to unstack into, use rename instead")
97+
self.old_dim_name = old_dim_name
98+
self.unstacked_dims = unstacked_dims
99+
self.unstacked_lengths = unstacked_lengths
100+
101+
def make_node(self, x):
102+
x = as_xtensor(x)
103+
if self.old_dim_name not in x.type.dims:
104+
raise ValueError(
105+
f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}"
106+
)
107+
if not set(self.unstacked_dims).isdisjoint(x.type.dims):
108+
raise ValueError(
109+
f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}"
110+
)
111+
if x.type.ndim == 1:
112+
batch_dims, batch_shape = (), ()
113+
else:
114+
batch_dims, batch_shape = zip(
115+
*(
116+
(dim, shape)
117+
for dim, shape in zip(x.type.dims, x.type.shape)
118+
if dim != self.old_dim_name
119+
)
120+
)
121+
122+
output = xtensor(
123+
dtype=x.type.dtype,
124+
shape=(*batch_shape, *self.unstacked_lengths),
125+
dims=(*batch_dims, *self.unstacked_dims),
126+
)
127+
return Apply(self, [x], [output])
128+
129+
130+
def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
131+
if dim is not None:
132+
if dims:
133+
raise ValueError(
134+
"Cannot use both positional dim and keyword dims in unstack"
135+
)
136+
dims = dim
137+
138+
y = x
139+
for old_dim_name, unstacked_dict in dims.items():
140+
y = UnStack(
141+
old_dim_name, tuple(unstacked_dict.keys()), tuple(unstacked_dict.values())
142+
)(y)
143+
return y

tests/xtensor/test_shape.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from xarray import DataArray
1111

12-
from pytensor.xtensor.shape import stack
12+
from pytensor.xtensor.shape import stack, unstack
1313
from pytensor.xtensor.type import xtensor
1414
from tests.xtensor.util import xr_assert_allclose, xr_function
1515

@@ -102,3 +102,59 @@ def test_multiple_stacks():
102102
res = fn(x_test)
103103
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
104104
xr_assert_allclose(res[0], expected_res)
105+
106+
107+
def test_unstack():
108+
unstacked_dims = {"a": 2, "b": 3, "c": 5, "d": 7}
109+
dims = ("abcd",)
110+
x = xtensor("x", dims=dims, shape=(2 * 3 * 5 * 7,))
111+
outs = [
112+
unstack(
113+
x,
114+
abcd=(
115+
{d: l for d, l in unstacked_dims.items() if d in dims_to_unstack}
116+
| (
117+
{}
118+
if set(dims_to_unstack) == set(unstacked_dims)
119+
else {
120+
"other": int(
121+
np.prod(
122+
[
123+
l
124+
for d, l in unstacked_dims.items()
125+
if d not in dims_to_unstack
126+
]
127+
)
128+
)
129+
}
130+
)
131+
),
132+
)
133+
for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2)
134+
]
135+
fn = xr_function([x], outs)
136+
# we test through the complementary operation in xarray to avoid needing coords
137+
# which are required for unstack. We end up with a subset of {a, b, c, d} and
138+
# other after unstacking, so we create the fully unstacked dataarray
139+
# and stack to create this extra "other" dimension as needed
140+
x_test = DataArray(
141+
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
142+
list(unstacked_dims.values())
143+
),
144+
dims=list(unstacked_dims.keys()),
145+
)
146+
res = fn(x_test)
147+
148+
expected_res = [
149+
x_test.stack(
150+
{}
151+
if set(dims_to_unstack) == set(unstacked_dims)
152+
else {"other": [d for d in unstacked_dims if d not in dims_to_unstack]}
153+
)
154+
for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2)
155+
]
156+
for res_i, expected_res_i in zip(res, expected_res):
157+
assert res_i.shape == expected_res_i.shape
158+
# the shapes are right but the "other" one has the elements in different order
159+
# I think it is an issue with the test not the function but not sure
160+
# xr_assert_allclose(res_i, expected_res_i)

0 commit comments

Comments
 (0)