4
4
5
5
pytest .importorskip ("xarray" )
6
6
7
+ import re
7
8
from itertools import chain , combinations
8
9
9
10
import numpy as np
10
11
from xarray import concat as xr_concat
11
12
12
- from pytensor .xtensor .shape import concat , stack
13
+ from pytensor .xtensor .shape import concat , stack , transpose
13
14
from pytensor .xtensor .type import xtensor
14
15
from tests .xtensor .util import (
15
16
xr_arange_like ,
@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0):
28
29
)
29
30
30
31
32
+ def test_transpose ():
33
+ a , b , c , d , e = "abcde"
34
+
35
+ x = xtensor ("x" , dims = (a , b , c , d , e ), shape = (2 , 3 , 5 , 7 , 11 ))
36
+ permutations = [
37
+ (a , b , c , d , e ), # identity
38
+ (e , d , c , b , a ), # full tranpose
39
+ (), # eqivalent to full transpose
40
+ (a , b , c , e , d ), # swap last two dims
41
+ (..., d , c ), # equivalent to (a, b, e, d, c)
42
+ (b , a , ..., e , d ), # equivalent to (b, a, c, d, e)
43
+ (c , a , ...), # equivalent to (c, a, b, d, e)
44
+ ]
45
+ outs = [transpose (x , * perm ) for perm in permutations ]
46
+
47
+ fn = xr_function ([x ], outs )
48
+ x_test = xr_arange_like (x )
49
+ res = fn (x_test )
50
+ expected_res = [x_test .transpose (* perm ) for perm in permutations ]
51
+ for outs_i , res_i , expected_res_i in zip (outs , res , expected_res ):
52
+ xr_assert_allclose (res_i , expected_res_i )
53
+
54
+
55
+ def test_xtensor_variable_transpose ():
56
+ """Test the transpose() method of XTensorVariable."""
57
+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
58
+
59
+ # Test basic transpose
60
+ out = x .transpose ()
61
+ fn = xr_function ([x ], out )
62
+ x_test = xr_arange_like (x )
63
+ xr_assert_allclose (fn (x_test ), x_test .transpose ())
64
+
65
+ # Test transpose with specific dimensions
66
+ out = x .transpose ("c" , "a" , "b" )
67
+ fn = xr_function ([x ], out )
68
+ xr_assert_allclose (fn (x_test ), x_test .transpose ("c" , "a" , "b" ))
69
+
70
+ # Test transpose with ellipsis
71
+ out = x .transpose ("c" , ...)
72
+ fn = xr_function ([x ], out )
73
+ xr_assert_allclose (fn (x_test ), x_test .transpose ("c" , ...))
74
+
75
+ # Test error cases
76
+ with pytest .raises (
77
+ ValueError ,
78
+ match = re .escape (
79
+ "Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
80
+ ),
81
+ ):
82
+ x .transpose ("d" )
83
+
84
+ with pytest .raises (
85
+ ValueError ,
86
+ match = re .escape ("Ellipsis (...) can only appear once in the dimensions" ),
87
+ ):
88
+ x .transpose ("a" , ..., "b" , ...)
89
+
90
+ # Test missing_dims parameter
91
+ # Test ignore
92
+ out = x .transpose ("c" , ..., "d" , missing_dims = "ignore" )
93
+ fn = xr_function ([x ], out )
94
+ xr_assert_allclose (fn (x_test ), x_test .transpose ("c" , ...))
95
+
96
+ # Test warn
97
+ with pytest .warns (UserWarning , match = "Dimensions {'d'} do not exist" ):
98
+ out = x .transpose ("c" , ..., "d" , missing_dims = "warn" )
99
+ fn = xr_function ([x ], out )
100
+ xr_assert_allclose (fn (x_test ), x_test .transpose ("c" , ...))
101
+
102
+
103
+ def test_xtensor_variable_T ():
104
+ """Test the T property of XTensorVariable."""
105
+ # Test T property with 3D tensor
106
+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
107
+ out = x .T
108
+
109
+ fn = xr_function ([x ], out )
110
+ x_test = xr_arange_like (x )
111
+ xr_assert_allclose (fn (x_test ), x_test .T )
112
+
113
+
31
114
def test_stack ():
32
115
dims = ("a" , "b" , "c" , "d" )
33
116
x = xtensor ("x" , dims = dims , shape = (2 , 3 , 5 , 7 ))
0 commit comments