Skip to content

Commit c2f2971

Browse files
Align test_linalg.py with master
1 parent 08f5ca3 commit c2f2971

File tree

1 file changed

+22
-27
lines changed

1 file changed

+22
-27
lines changed

dpnp/tests/test_linalg.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
)
1616

1717
import dpnp
18-
import dpnp.linalg
1918

2019
from .helper import (
2120
assert_dtype_allclose,
@@ -279,15 +278,12 @@ def test_cholesky_errors(self):
279278

280279

281280
class TestCond:
282-
def setup_method(self):
283-
numpy.random.seed(70)
281+
_norms = [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
284282

285283
@pytest.mark.parametrize(
286-
"shape", [(0, 4, 4), (4, 0, 3, 3)], ids=["(0, 5, 3)", "(4, 0, 2, 3)"]
287-
)
288-
@pytest.mark.parametrize(
289-
"p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
284+
"shape", [(0, 4, 4), (4, 0, 3, 3)], ids=["(0, 4, 4)", "(4, 0, 3, 3)"]
290285
)
286+
@pytest.mark.parametrize("p", _norms)
291287
def test_empty(self, shape, p):
292288
a = numpy.empty(shape)
293289
ia = dpnp.array(a)
@@ -296,26 +292,27 @@ def test_empty(self, shape, p):
296292
expected = numpy.linalg.cond(a, p=p)
297293
assert_dtype_allclose(result, expected)
298294

295+
# TODO: uncomment once numpy 2.3.3 release is published
296+
# @testing.with_requires("numpy>=2.3.3")
299297
@pytest.mark.parametrize(
300298
"dtype", get_all_dtypes(no_none=True, no_bool=True)
301299
)
302300
@pytest.mark.parametrize(
303301
"shape", [(4, 4), (2, 4, 3, 3)], ids=["(4, 4)", "(2, 4, 3, 3)"]
304302
)
305-
@pytest.mark.parametrize(
306-
"p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
307-
)
303+
@pytest.mark.parametrize("p", _norms)
308304
def test_basic(self, dtype, shape, p):
309305
a = generate_random_numpy_array(shape, dtype)
310306
ia = dpnp.array(a)
311307

312308
result = dpnp.linalg.cond(ia, p=p)
313309
expected = numpy.linalg.cond(a, p=p)
310+
# TODO: remove when numpy#29333 is released
311+
if numpy_version() < "2.3.3":
312+
expected = expected.real
314313
assert_dtype_allclose(result, expected, factor=16)
315314

316-
@pytest.mark.parametrize(
317-
"p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
318-
)
315+
@pytest.mark.parametrize("p", _norms)
319316
def test_bool(self, p):
320317
a = numpy.array([[True, True], [True, False]])
321318
ia = dpnp.array(a)
@@ -324,9 +321,7 @@ def test_bool(self, p):
324321
expected = numpy.linalg.cond(a, p=p)
325322
assert_dtype_allclose(result, expected)
326323

327-
@pytest.mark.parametrize(
328-
"p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
329-
)
324+
@pytest.mark.parametrize("p", _norms)
330325
def test_nan_to_inf(self, p):
331326
a = numpy.zeros((2, 2))
332327
ia = dpnp.array(a)
@@ -344,9 +339,7 @@ def test_nan_to_inf(self, p):
344339
else:
345340
assert_raises(dpnp.linalg.LinAlgError, dpnp.linalg.cond, ia, p=p)
346341

347-
@pytest.mark.parametrize(
348-
"p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
349-
)
342+
@pytest.mark.parametrize("p", _norms)
350343
@pytest.mark.parametrize(
351344
"stride",
352345
[(-2, -3, 2, -2), (-2, 4, -4, -4), (2, 3, 4, 4), (-1, 3, 3, -3)],
@@ -358,21 +351,23 @@ def test_nan_to_inf(self, p):
358351
],
359352
)
360353
def test_strided(self, p, stride):
361-
A = numpy.random.rand(6, 8, 10, 10)
362-
B = dpnp.asarray(A)
354+
A = generate_random_numpy_array(
355+
(6, 8, 10, 10), seed_value=70, low=0, high=1
356+
)
357+
iA = dpnp.array(A)
363358
slices = tuple(slice(None, None, stride[i]) for i in range(A.ndim))
364-
a = A[slices]
365-
b = B[slices]
359+
a, ia = A[slices], iA[slices]
366360

367-
result = dpnp.linalg.cond(b, p=p)
361+
result = dpnp.linalg.cond(ia, p=p)
368362
expected = numpy.linalg.cond(a, p=p)
369363
assert_dtype_allclose(result, expected, factor=24)
370364

371-
def test_error(self):
365+
@pytest.mark.parametrize("xp", [dpnp, numpy])
366+
def test_error(self, xp):
372367
# cond is not defined on empty arrays
373-
ia = dpnp.empty((2, 0))
368+
a = xp.empty((2, 0))
374369
with pytest.raises(ValueError):
375-
dpnp.linalg.cond(ia, p=1)
370+
xp.linalg.cond(a, p=1)
376371

377372

378373
class TestDet:

0 commit comments

Comments
 (0)