Skip to content

Commit 183f252

Browse files
committed
Update linalg_tests/test_solve.py
1 parent f59daaf commit 183f252

File tree

1 file changed

+46
-21
lines changed

1 file changed

+46
-21
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
@testing.parameterize(
1616
*testing.product(
1717
{
18+
# "batched_gesv_limit": [None, 0],
1819
"order": ["C", "F"],
1920
}
2021
)
2122
)
2223
@testing.fix_random()
2324
class TestSolve(unittest.TestCase):
24-
# TODO: add get_batched_gesv_limit
25+
2526
# def setUp(self):
2627
# if self.batched_gesv_limit is not None:
2728
# self.old_limit = get_batched_gesv_limit()
@@ -32,6 +33,7 @@ class TestSolve(unittest.TestCase):
3233
# set_batched_gesv_limit(self.old_limit)
3334

3435
@testing.for_dtypes("ifdFD")
36+
# TODO(kataoka): Fix contiguity
3537
@testing.numpy_cupy_allclose(
3638
atol=1e-3, contiguous_check=False, type_check=has_support_aspect64()
3739
)
@@ -47,6 +49,7 @@ def check_x(self, a_shape, b_shape, xp, dtype):
4749
testing.assert_array_equal(b_copy, b)
4850
return result
4951

52+
@testing.with_requires("numpy>=2.0")
5053
def test_solve(self):
5154
self.check_x((4, 4), (4,))
5255
self.check_x((5, 5), (5, 2))
@@ -55,15 +58,9 @@ def test_solve(self):
5558
self.check_x((0, 0), (0,))
5659
self.check_x((0, 0), (0, 2))
5760
self.check_x((0, 2, 2), (0, 2, 3))
58-
# In numpy 2.0 the broadcast ambiguity has been removed and now
59-
# b is treaded as a single vector if and only if it is 1-dimensional;
60-
# for other cases this signature must be followed
61-
# (..., m, m), (..., m, n) -> (..., m, n)
62-
# https://github.com/numpy/numpy/pull/25914
63-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
64-
self.check_x((2, 4, 4), (2, 4))
65-
self.check_x((2, 3, 2, 2), (2, 3, 2))
66-
self.check_x((0, 2, 2), (0, 2))
61+
# Allowed since numpy 2
62+
self.check_x((2, 3, 3), (3,))
63+
self.check_x((2, 5, 3, 3), (3,))
6764

6865
def check_shape(self, a_shape, b_shape, error_types):
6966
for xp, error_type in error_types.items():
@@ -76,12 +73,14 @@ def check_shape(self, a_shape, b_shape, error_types):
7673
# NumPy with OpenBLAS returns an empty array
7774
# while numpy with OneMKL raises LinAlgError
7875
@pytest.mark.skip("Undefined behavior")
76+
@testing.numpy_cupy_allclose()
7977
def test_solve_singular_empty(self, xp):
8078
a = xp.zeros((3, 3)) # singular
8179
b = xp.empty((3, 0)) # nrhs = 0
8280
# LinAlgError("Singular matrix") is not raised
8381
return xp.linalg.solve(a, b)
8482

83+
@testing.with_requires("numpy>=2.0")
8584
def test_invalid_shape(self):
8685
linalg_errors = {
8786
numpy: numpy.linalg.LinAlgError,
@@ -96,11 +95,35 @@ def test_invalid_shape(self):
9695
self.check_shape((3, 3), (2,), value_errors)
9796
self.check_shape((3, 3), (2, 2), value_errors)
9897
self.check_shape((3, 3, 4), (3,), linalg_errors)
99-
# Since numpy >= 2.0, this case does not raise an error
100-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
101-
self.check_shape((2, 3, 3), (3,), value_errors)
10298
self.check_shape((3, 3), (0,), value_errors)
10399
self.check_shape((0, 3, 4), (3,), linalg_errors)
100+
self.check_shape((3, 3), (), value_errors)
101+
# Not allowed since numpy 2
102+
self.check_shape(
103+
(0, 2, 2),
104+
(
105+
0,
106+
2,
107+
),
108+
value_errors,
109+
)
110+
self.check_shape(
111+
(2, 4, 4),
112+
(
113+
2,
114+
4,
115+
),
116+
value_errors,
117+
)
118+
self.check_shape(
119+
(2, 3, 2, 2),
120+
(
121+
2,
122+
3,
123+
2,
124+
),
125+
value_errors,
126+
)
104127

105128

106129
@testing.parameterize(
@@ -113,6 +136,7 @@ def test_invalid_shape(self):
113136
)
114137
@testing.fix_random()
115138
class TestTensorSolve(unittest.TestCase):
139+
116140
@testing.for_dtypes("ifdFD")
117141
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
118142
def test_tensorsolve(self, xp, dtype):
@@ -131,6 +155,7 @@ def test_tensorsolve(self, xp, dtype):
131155
)
132156
)
133157
class TestInv(unittest.TestCase):
158+
134159
@testing.for_dtypes("ifdFD")
135160
@_condition.retry(10)
136161
def check_x(self, a_shape, dtype):
@@ -140,7 +165,6 @@ def check_x(self, a_shape, dtype):
140165
a_gpu_copy = a_gpu.copy()
141166
result_cpu = numpy.linalg.inv(a_cpu)
142167
result_gpu = cupy.linalg.inv(a_gpu)
143-
144168
assert_dtype_allclose(result_gpu, result_cpu)
145169
testing.assert_array_equal(a_gpu_copy, a_gpu)
146170

@@ -170,6 +194,7 @@ def test_invalid_shape(self):
170194

171195

172196
class TestInvInvalid(unittest.TestCase):
197+
173198
@testing.for_dtypes("ifdFD")
174199
def test_inv(self, dtype):
175200
for xp in (numpy, cupy):
@@ -192,6 +217,7 @@ def test_batched_inv(self, dtype):
192217

193218

194219
class TestPinv(unittest.TestCase):
220+
195221
@testing.for_dtypes("ifdFD")
196222
@_condition.retry(10)
197223
def check_x(self, a_shape, rcond, dtype):
@@ -234,6 +260,7 @@ def test_pinv_size_0(self):
234260

235261

236262
class TestLstsq:
263+
237264
@testing.for_dtypes("ifdFD")
238265
@testing.numpy_cupy_allclose(atol=1e-3, type_check=has_support_aspect64())
239266
def check_lstsq_solution(
@@ -312,20 +339,18 @@ def test_invalid_shapes(self):
312339
self.check_invalid_shapes((3, 3), (2, 2))
313340
self.check_invalid_shapes((4, 3), (10, 3, 3))
314341

315-
# dpnp.linalg.lstsq() does not raise a FutureWarning
316-
# because dpnp did not have a previous implementation of dpnp.linalg.lstsq()
317-
# and there is no need to get rid of old deprecated behavior as numpy did.
318-
@pytest.mark.skip("No support of deprecated behavior")
342+
@testing.with_requires("numpy>=2.0")
319343
@testing.for_float_dtypes(no_float16=True)
320344
@testing.numpy_cupy_allclose(atol=1e-3)
321-
def test_warn_rcond(self, xp, dtype):
345+
def test_nowarn_rcond(self, xp, dtype):
322346
a = testing.shaped_random((3, 3), xp, dtype)
323347
b = testing.shaped_random((3,), xp, dtype)
324-
with testing.assert_warns(FutureWarning):
325-
return xp.linalg.lstsq(a, b)
348+
# FutureWarning is no longer emitted
349+
return xp.linalg.lstsq(a, b)
326350

327351

328352
class TestTensorInv(unittest.TestCase):
353+
329354
@testing.for_dtypes("ifdFD")
330355
@_condition.retry(10)
331356
def check_x(self, a_shape, ind, dtype):

0 commit comments

Comments
 (0)