Skip to content

Commit 496f949

Browse files
author
Vahid Tavanashad
committed
update vector_norm and matrix_norm tests for empty arrays and improve coverage
1 parent 4f6125c commit 496f949

File tree

1 file changed

+29
-34
lines changed

1 file changed

+29
-34
lines changed

dpnp/tests/test_linalg.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2298,49 +2298,44 @@ def test_matrix_norm(self, ord, keepdims):
22982298
expected = numpy.linalg.matrix_norm(a, ord=ord, keepdims=keepdims)
22992299
assert_dtype_allclose(result, expected)
23002300

2301-
@pytest.mark.parametrize(
2302-
"xp",
2303-
[
2304-
dpnp,
2305-
pytest.param(
2306-
numpy,
2307-
marks=pytest.mark.skipif(
2308-
numpy_version() < "2.3.0",
2309-
reason="numpy raises an error",
2310-
),
2311-
),
2312-
],
2313-
)
2301+
@testing.with_requires("numpy>=2.3")
23142302
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.int32])
23152303
@pytest.mark.parametrize(
23162304
"shape, axis", [[(2, 0), None], [(2, 0), (0, 1)], [(0, 2), (0, 1)]]
23172305
)
23182306
@pytest.mark.parametrize("ord", [None, "fro", "nuc", 1, 2, dpnp.inf])
2319-
def test_matrix_norm_empty(self, xp, dtype, shape, axis, ord):
2320-
x = xp.zeros(shape, dtype=dtype)
2321-
sc = dtype(0.0) if dtype == dpnp.float32 else 0.0
2322-
assert_equal(xp.linalg.norm(x, axis=axis, ord=ord), sc)
2307+
@pytest.mark.parametrize("keepdims", [True, False])
2308+
def test_matrix_norm_empty(self, dtype, shape, axis, ord, keepdims):
2309+
a = numpy.zeros(shape, dtype=dtype)
2310+
ia = dpnp.array(a)
2311+
result = dpnp.linalg.matrix_norm(
2312+
ia, axis=axis, ord=ord, keepdims=keepdims
2313+
)
2314+
expected = dpnp.linalg.matrix_norm(
2315+
a, axis=axis, ord=ord, keepdims=keepdims
2316+
)
2317+
assert_dtype_allclose(result, expected)
23232318

2324-
@pytest.mark.parametrize(
2325-
"xp",
2326-
[
2327-
dpnp,
2328-
pytest.param(
2329-
numpy,
2330-
marks=pytest.mark.skipif(
2331-
numpy_version() < "2.3.0",
2332-
reason="numpy raises an error",
2333-
),
2334-
),
2335-
],
2336-
)
2319+
@testing.with_requires("numpy>=2.3")
23372320
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.int32])
23382321
@pytest.mark.parametrize("axis", [None, 0])
23392322
@pytest.mark.parametrize("ord", [None, 1, 2, dpnp.inf])
2340-
def test_vector_norm_empty(self, xp, dtype, axis, ord):
2341-
x = xp.zeros(0, dtype=dtype)
2342-
sc = dtype(0.0) if dtype == dpnp.float32 else 0.0
2343-
assert_equal(xp.linalg.vector_norm(x, axis=axis, ord=ord), sc)
2323+
@pytest.mark.parametrize("keepdims", [True, False])
2324+
def test_vector_norm_empty(self, dtype, axis, ord, keepdims):
2325+
a = numpy.zeros(0, dtype=dtype)
2326+
ia = dpnp.array(a)
2327+
result = dpnp.linalg.vector_norm(
2328+
ia, axis=axis, ord=ord, keepdims=keepdims
2329+
)
2330+
expected = numpy.linalg.vector_norm(
2331+
a, axis=axis, ord=ord, keepdims=keepdims
2332+
)
2333+
assert_dtype_allclose(result, expected)
2334+
if keepdims:
2335+
# norm and vector_norm have different paths in dpnp when keepdims=True,
2336+
# to cover both of them test with norm as well
2337+
result = dpnp.linalg.norm(ia, axis=axis, ord=ord, keepdims=keepdims)
2338+
assert_dtype_allclose(result, expected)
23442339

23452340
@testing.with_requires("numpy>=2.0")
23462341
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)