Skip to content

Commit 27ea9aa

Browse files
committed
add test cases
1 parent 84f24fe commit 27ea9aa

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

tests/test_tensor.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Sequence
2+
13
import numpy as np
24
import pytest
35
import xarray as xr
@@ -8,9 +10,19 @@
810

911
@pytest.mark.parametrize(
1012
"axes",
11-
["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"],
13+
[
14+
"yx",
15+
"xy",
16+
"cyx",
17+
"yxc",
18+
"bczyx",
19+
"xyz",
20+
"xyzc",
21+
"bzyxc",
22+
("batch", "channel", "x", "y"),
23+
],
1224
)
13-
def test_transpose_tensor_2d(axes: str):
25+
def test_transpose_tensor_2d(axes: Sequence[str]):
1426

1527
tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None)
1628
transposed = tensor.transpose([AxisId(a) for a in axes])
@@ -19,9 +31,18 @@ def test_transpose_tensor_2d(axes: str):
1931

2032
@pytest.mark.parametrize(
2133
"axes",
22-
["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"],
34+
[
35+
"zyx",
36+
"cyzx",
37+
"yzixc",
38+
"bczyx",
39+
"xyz",
40+
"xyzc",
41+
"bzyxtc",
42+
("batch", "channel", "x", "y", "z"),
43+
],
2344
)
24-
def test_transpose_tensor_3d(axes: str):
45+
def test_transpose_tensor_3d(axes: Sequence[str]):
2546
tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None)
2647
transposed = tensor.transpose([AxisId(a) for a in axes])
2748
assert transposed.ndim == len(axes)

0 commit comments

Comments
 (0)