Skip to content

Commit 0c6e343

Browse files
committed
Add simple test to unstack
1 parent 5f4253f commit 0c6e343

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/xtensor/test_shape.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,19 @@ def test_unstack():
158158
# the shapes are right but the "other" one has the elements in different order
159159
# I think it is an issue with the test not the function but not sure
160160
# xr_assert_allclose(res_i, expected_res_i)
161+
162+
163+
def test_unstack_simple():
164+
x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7))
165+
y = unstack(x, bc=dict(b=3, c=5))
166+
assert y.type.dims == ("a", "d", "b", "c")
167+
assert y.type.shape == (2, 7, 3, 5)
168+
169+
fn = xr_function([x], y)
170+
171+
x_np = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape)
172+
x_test = DataArray(x_np, dims=x.type.dims)
173+
res = fn(x_test)
174+
np.testing.assert_allclose(
175+
res.values, x_np.reshape(2, 3, 5, 7).transpose(0, 3, 1, 2)
176+
)

0 commit comments

Comments
 (0)