5
5
# import jax
6
6
import numpy as np
7
7
import pytest
8
- # import torch
8
+ import torch
9
9
import paddle
10
10
11
11
import array_api_compat
@@ -73,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
73
73
"""
74
74
subprocess .run ([sys .executable , "-c" , code ], check = True )
75
75
76
- # def test_jax_zero_gradient():
77
- # jx = jax.numpy.arange(4)
78
- # jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
79
- # assert (array_api_compat.get_namespace(jax_zero) is
80
- # array_api_compat.get_namespace(jx))
76
+ def test_jax_zero_gradient ():
77
+ jx = jax .numpy .arange (4 )
78
+ jax_zero = jax .vmap (jax .grad (jax .numpy .float32 , allow_int = True ))(jx )
79
+ assert (array_api_compat .get_namespace (jax_zero ) is
80
+ array_api_compat .get_namespace (jx ))
81
81
82
82
def test_array_namespace_errors ():
83
83
pytest .raises (TypeError , lambda : array_namespace ([1 ]))
@@ -87,53 +87,32 @@ def test_array_namespace_errors():
87
87
pytest .raises (TypeError , lambda : array_namespace ((x , x )))
88
88
pytest .raises (TypeError , lambda : array_namespace (x , (x , x )))
89
89
90
- # def test_array_namespace_errors_torch():
91
- # y = torch.asarray([1, 2])
92
- # x = np.asarray([1, 2])
93
- # pytest.raises(TypeError, lambda: array_namespace(x, y))
90
+ def test_array_namespace_errors_torch ():
91
+ y = torch .asarray ([1 , 2 ])
92
+ x = np .asarray ([1 , 2 ])
93
+ pytest .raises (TypeError , lambda : array_namespace (x , y ))
94
94
95
95
96
96
def test_array_namespace_errors_paddle ():
97
97
y = paddle .to_tensor ([1 , 2 ])
98
98
x = np .asarray ([1 , 2 ])
99
99
pytest .raises (TypeError , lambda : array_namespace (x , y ))
100
100
101
-
102
- # def test_api_version():
103
- # x = torch.asarray([1, 2])
104
- # torch_ = import_("torch", wrapper=True)
105
- # assert array_namespace(x, api_version="2023.12") == torch_
106
- # assert array_namespace(x, api_version=None) == torch_
107
- # assert array_namespace(x) == torch_
108
- # # Should issue a warning
109
- # with warnings.catch_warnings(record=True) as w:
110
- # assert array_namespace(x, api_version="2021.12") == torch_
111
- # assert len(w) == 1
112
- # assert "2021.12" in str(w[0].message)
113
-
114
- # # Should issue a warning
115
- # with warnings.catch_warnings(record=True) as w:
116
- # assert array_namespace(x, api_version="2022.12") == torch_
117
- # assert len(w) == 1
118
- # assert "2022.12" in str(w[0].message)
119
-
120
- # pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
121
-
122
101
def test_api_version ():
123
- x = paddle .asarray ([1 , 2 ])
124
- paddle_ = import_ ("paddle " , wrapper = True )
125
- assert array_namespace (x , api_version = "2023.12" ) == paddle_
126
- assert array_namespace (x , api_version = None ) == paddle_
127
- assert array_namespace (x ) == paddle_
102
+ x = torch .asarray ([1 , 2 ])
103
+ torch_ = import_ ("torch " , wrapper = True )
104
+ assert array_namespace (x , api_version = "2023.12" ) == torch_
105
+ assert array_namespace (x , api_version = None ) == torch_
106
+ assert array_namespace (x ) == torch_
128
107
# Should issue a warning
129
108
with warnings .catch_warnings (record = True ) as w :
130
- assert array_namespace (x , api_version = "2021.12" ) == paddle_
109
+ assert array_namespace (x , api_version = "2021.12" ) == torch_
131
110
assert len (w ) == 1
132
111
assert "2021.12" in str (w [0 ].message )
133
112
134
113
# Should issue a warning
135
114
with warnings .catch_warnings (record = True ) as w :
136
- assert array_namespace (x , api_version = "2022.12" ) == paddle_
115
+ assert array_namespace (x , api_version = "2022.12" ) == torch_
137
116
assert len (w ) == 1
138
117
assert "2022.12" in str (w [0 ].message )
139
118
0 commit comments