1+ from typing import Sequence
2+
13import numpy as np
24import pytest
35import xarray as xr
810
911@pytest .mark .parametrize (
1012 "axes" ,
11- ["yx" , "xy" , "cyx" , "yxc" , "bczyx" , "xyz" , "xyzc" , "bzyxc" ],
13+ [
14+ "yx" ,
15+ "xy" ,
16+ "cyx" ,
17+ "yxc" ,
18+ "bczyx" ,
19+ "xyz" ,
20+ "xyzc" ,
21+ "bzyxc" ,
22+ ("batch" , "channel" , "x" , "y" ),
23+ ],
1224)
13- def test_transpose_tensor_2d (axes : str ):
25+ def test_transpose_tensor_2d (axes : Sequence [ str ] ):
1426
1527 tensor = Tensor .from_numpy (np .random .rand (256 , 256 ), dims = None )
1628 transposed = tensor .transpose ([AxisId (a ) for a in axes ])
@@ -19,9 +31,18 @@ def test_transpose_tensor_2d(axes: str):
1931
2032@pytest .mark .parametrize (
2133 "axes" ,
22- ["zyx" , "cyzx" , "yzixc" , "bczyx" , "xyz" , "xyzc" , "bzyxtc" ],
34+ [
35+ "zyx" ,
36+ "cyzx" ,
37+ "yzixc" ,
38+ "bczyx" ,
39+ "xyz" ,
40+ "xyzc" ,
41+ "bzyxtc" ,
42+ ("batch" , "channel" , "x" , "y" , "z" ),
43+ ],
2344)
24- def test_transpose_tensor_3d (axes : str ):
45+ def test_transpose_tensor_3d (axes : Sequence [ str ] ):
2546 tensor = Tensor .from_numpy (np .random .rand (64 , 64 , 64 ), dims = None )
2647 transposed = tensor .transpose ([AxisId (a ) for a in axes ])
2748 assert transposed .ndim == len (axes )
0 commit comments