Skip to content

Commit 580f7dd

Browse files
committed
Update linalg_tests/test_solve.py
1 parent 7d76d34 commit 580f7dd

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
@testing.parameterize(
1616
*testing.product(
1717
{
18+
# "batched_gesv_limit": [None, 0],
1819
"order": ["C", "F"],
1920
}
2021
)
2122
)
2223
@testing.fix_random()
2324
class TestSolve(unittest.TestCase):
24-
# TODO: add get_batched_gesv_limit
25+
2526
# def setUp(self):
2627
# if self.batched_gesv_limit is not None:
2728
# self.old_limit = get_batched_gesv_limit()
@@ -32,6 +33,7 @@ class TestSolve(unittest.TestCase):
3233
# set_batched_gesv_limit(self.old_limit)
3334

3435
@testing.for_dtypes("ifdFD")
36+
# TODO(kataoka): Fix contiguity
3537
@testing.numpy_cupy_allclose(
3638
atol=1e-3, contiguous_check=False, type_check=has_support_aspect64()
3739
)
@@ -71,6 +73,7 @@ def check_shape(self, a_shape, b_shape, error_types):
7173
# NumPy with OpenBLAS returns an empty array
7274
# while numpy with OneMKL raises LinAlgError
7375
@pytest.mark.skip("Undefined behavior")
76+
@testing.numpy_cupy_allclose()
7477
def test_solve_singular_empty(self, xp):
7578
a = xp.zeros((3, 3)) # singular
7679
b = xp.empty((3, 0)) # nrhs = 0
@@ -94,10 +97,33 @@ def test_invalid_shape(self):
9497
self.check_shape((3, 3, 4), (3,), linalg_errors)
9598
self.check_shape((3, 3), (0,), value_errors)
9699
self.check_shape((0, 3, 4), (3,), linalg_errors)
97-
# Not allowed since numpy 2.0
98-
self.check_shape((0, 2, 2), (0, 2), value_errors)
99-
self.check_shape((2, 4, 4), (2, 4), value_errors)
100-
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)
100+
self.check_shape((3, 3), (), value_errors)
101+
# Not allowed since numpy 2
102+
self.check_shape(
103+
(0, 2, 2),
104+
(
105+
0,
106+
2,
107+
),
108+
value_errors,
109+
)
110+
self.check_shape(
111+
(2, 4, 4),
112+
(
113+
2,
114+
4,
115+
),
116+
value_errors,
117+
)
118+
self.check_shape(
119+
(2, 3, 2, 2),
120+
(
121+
2,
122+
3,
123+
2,
124+
),
125+
value_errors,
126+
)
101127

102128

103129
@testing.parameterize(
@@ -110,6 +136,7 @@ def test_invalid_shape(self):
110136
)
111137
@testing.fix_random()
112138
class TestTensorSolve(unittest.TestCase):
139+
113140
@testing.for_dtypes("ifdFD")
114141
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
115142
def test_tensorsolve(self, xp, dtype):
@@ -128,6 +155,7 @@ def test_tensorsolve(self, xp, dtype):
128155
)
129156
)
130157
class TestInv(unittest.TestCase):
158+
131159
@testing.for_dtypes("ifdFD")
132160
@_condition.retry(10)
133161
def check_x(self, a_shape, dtype):
@@ -137,7 +165,6 @@ def check_x(self, a_shape, dtype):
137165
a_gpu_copy = a_gpu.copy()
138166
result_cpu = numpy.linalg.inv(a_cpu)
139167
result_gpu = cupy.linalg.inv(a_gpu)
140-
141168
assert_dtype_allclose(result_gpu, result_cpu)
142169
testing.assert_array_equal(a_gpu_copy, a_gpu)
143170

@@ -167,6 +194,7 @@ def test_invalid_shape(self):
167194

168195

169196
class TestInvInvalid(unittest.TestCase):
197+
170198
@testing.for_dtypes("ifdFD")
171199
def test_inv(self, dtype):
172200
for xp in (numpy, cupy):
@@ -189,6 +217,7 @@ def test_batched_inv(self, dtype):
189217

190218

191219
class TestPinv(unittest.TestCase):
220+
192221
@testing.for_dtypes("ifdFD")
193222
@_condition.retry(10)
194223
def check_x(self, a_shape, rcond, dtype):
@@ -231,6 +260,7 @@ def test_pinv_size_0(self):
231260

232261

233262
class TestLstsq:
263+
234264
@testing.for_dtypes("ifdFD")
235265
@testing.numpy_cupy_allclose(atol=1e-3, type_check=has_support_aspect64())
236266
def check_lstsq_solution(
@@ -309,20 +339,18 @@ def test_invalid_shapes(self):
309339
self.check_invalid_shapes((3, 3), (2, 2))
310340
self.check_invalid_shapes((4, 3), (10, 3, 3))
311341

312-
# dpnp.linalg.lstsq() does not raise a FutureWarning
313-
# because dpnp did not have a previous implementation of dpnp.linalg.lstsq()
314-
# and there is no need to get rid of old deprecated behavior as numpy did.
315-
@pytest.mark.skip("No support of deprecated behavior")
342+
@testing.with_requires("numpy>=2.0")
316343
@testing.for_float_dtypes(no_float16=True)
317344
@testing.numpy_cupy_allclose(atol=1e-3)
318-
def test_warn_rcond(self, xp, dtype):
345+
def test_nowarn_rcond(self, xp, dtype):
319346
a = testing.shaped_random((3, 3), xp, dtype)
320347
b = testing.shaped_random((3,), xp, dtype)
321-
with testing.assert_warns(FutureWarning):
322-
return xp.linalg.lstsq(a, b)
348+
# FutureWarning is no longer emitted
349+
return xp.linalg.lstsq(a, b)
323350

324351

325352
class TestTensorInv(unittest.TestCase):
353+
326354
@testing.for_dtypes("ifdFD")
327355
@_condition.retry(10)
328356
def check_x(self, a_shape, ind, dtype):

0 commit comments

Comments
 (0)