Skip to content

Commit e6e66ad

Browse files
Align TestSolve with cupy tests
1 parent d140f56 commit e6e66ad

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def check_x(self, a_shape, b_shape, xp, dtype):
4747
testing.assert_array_equal(b_copy, b)
4848
return result
4949

50+
@testing.with_requires("numpy>=2.0")
5051
def test_solve(self):
5152
self.check_x((4, 4), (4,))
5253
self.check_x((5, 5), (5, 2))
@@ -55,14 +56,9 @@ def test_solve(self):
5556
self.check_x((0, 0), (0,))
5657
self.check_x((0, 0), (0, 2))
5758
self.check_x((0, 2, 2), (0, 2, 3))
58-
# In numpy 2.0 the broadcast ambiguity has been removed and now
59-
# b is treaded as a single vector if and only if it is 1-dimensional;
60-
# for other cases this signature must be followed
61-
# (..., m, m), (..., m, n) -> (..., m, n)
62-
# https://github.com/numpy/numpy/pull/25914
63-
if numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
64-
self.check_x((2, 3, 3), (3,))
65-
self.check_x((2, 5, 3, 3), (3,))
59+
# Allowed since numpy 2
60+
self.check_x((2, 3, 3), (3,))
61+
self.check_x((2, 5, 3, 3), (3,))
6662

6763
def check_shape(self, a_shape, b_shape, error_types):
6864
for xp, error_type in error_types.items():
@@ -81,6 +77,7 @@ def test_solve_singular_empty(self, xp):
8177
# LinAlgError("Singular matrix") is not raised
8278
return xp.linalg.solve(a, b)
8379

80+
@testing.with_requires("numpy>=2.0")
8481
def test_invalid_shape(self):
8582
linalg_errors = {
8683
numpy: numpy.linalg.LinAlgError,
@@ -98,10 +95,9 @@ def test_invalid_shape(self):
9895
self.check_shape((3, 3), (0,), value_errors)
9996
self.check_shape((0, 3, 4), (3,), linalg_errors)
10097
# Not allowed since numpy 2.0
101-
if numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
102-
self.check_shape((0, 2, 2), (0, 2), value_errors)
103-
self.check_shape((2, 4, 4), (2, 4), value_errors)
104-
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)
98+
self.check_shape((0, 2, 2), (0, 2), value_errors)
99+
self.check_shape((2, 4, 4), (2, 4), value_errors)
100+
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)
105101

106102

107103
@testing.parameterize(

0 commit comments

Comments
 (0)