Skip to content

Commit 3001a9e

Browse files
author
Vahid Tavanashad
committed
add parametrize xp
1 parent 7ad964c commit 3001a9e

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

dpnp/tests/test_linalg.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
has_support_aspect64,
2626
is_cpu_device,
2727
is_cuda_device,
28+
numpy_version,
2829
)
2930
from .third_party.cupy import testing
3031

@@ -2299,27 +2300,48 @@ def test_matrix_norm(self, ord, keepdims):
22992300
expected = numpy.linalg.matrix_norm(a, ord=ord, keepdims=keepdims)
23002301
assert_dtype_allclose(result, expected)
23012302

2303+
@pytest.mark.parametrize(
2304+
"xp",
2305+
[
2306+
dpnp,
2307+
pytest.param(
2308+
numpy,
2309+
marks=pytest.mark.skipif(
2310+
numpy_version() < "2.3.0",
2311+
reason="numpy raises an error",
2312+
),
2313+
),
2314+
],
2315+
)
23022316
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.int32])
23032317
@pytest.mark.parametrize(
23042318
"shape_axis", [[(2, 0), None], [(2, 0), (0, 1)], [(0, 2), (0, 1)]]
23052319
)
23062320
@pytest.mark.parametrize("ord", [None, "fro", "nuc", 1, 2, dpnp.inf])
2307-
def test_matrix_norm_empty(self, dtype, shape_axis, ord):
2321+
def test_matrix_norm_empty(self, xp, dtype, shape_axis, ord):
23082322
shape, axis = shape_axis[0], shape_axis[1]
2309-
x = dpnp.zeros(shape, dtype=dtype)
2310-
2311-
# TODO: when similar changes in numpy are available,
2312-
# instead of assert_equal with zero, we should compare with numpy
2313-
assert_equal(dpnp.linalg.norm(x, axis=axis, ord=ord), 0)
2323+
x = xp.zeros(shape, dtype=dtype)
2324+
assert_equal(xp.linalg.norm(x, axis=axis, ord=ord), 0)
23142325

2326+
@pytest.mark.parametrize(
2327+
"xp",
2328+
[
2329+
dpnp,
2330+
pytest.param(
2331+
numpy,
2332+
marks=pytest.mark.skipif(
2333+
numpy_version() < "2.3.0",
2334+
reason="numpy raises an error",
2335+
),
2336+
),
2337+
],
2338+
)
23152339
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.int32])
23162340
@pytest.mark.parametrize("axis", [None, 0])
23172341
@pytest.mark.parametrize("ord", [None, 1, 2, dpnp.inf])
2318-
def test_vector_norm_empty(self, dtype, axis, ord):
2319-
x = dpnp.zeros(0, dtype=dtype)
2320-
# TODO: when similar changes in numpy are available,
2321-
# instead of assert_equal with zero, we should compare with numpy
2322-
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=ord), 0)
2342+
def test_vector_norm_empty(self, xp, dtype, axis, ord):
2343+
x = xp.zeros(0, dtype=dtype)
2344+
assert_equal(xp.linalg.vector_norm(x, axis=axis, ord=ord), 0)
23232345

23242346
@testing.with_requires("numpy>=2.0")
23252347
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)