Skip to content

Commit f12771d

Browse files
author
Vahid Tavanashad
committed
updates regarding sign function
1 parent dca2bfd commit f12771d

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

dpnp/tests/test_mathematical.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2309,35 +2309,26 @@ def test_float_remainder_fmod_nans_inf(func, dtype, lhs, rhs):
23092309
assert_equal(result, expected)
23102310

23112311

2312+
@testing.with_requires("numpy>=2.0.0")
23122313
@pytest.mark.parametrize(
2313-
"data",
2314-
[[2, 0, -2], [1.1, -1.1]],
2315-
ids=["[2, 0, -2]", "[1.1, -1.1]"],
2316-
)
2317-
@pytest.mark.parametrize(
2318-
"dtype", get_all_dtypes(no_bool=True, no_unsigned=True)
2314+
"dtype", get_all_dtypes(no_none=True, no_unsigned=True)
23192315
)
2320-
def test_sign(data, dtype):
2321-
np_a = numpy.array(data, dtype=dtype)
2322-
dpnp_a = dpnp.array(data, dtype=dtype)
2316+
def test_sign(dtype):
2317+
a = generate_random_numpy_array((2, 3), dtype=dtype)
2318+
ia = dpnp.array(a, dtype=dtype)
23232319

2324-
result = dpnp.sign(dpnp_a)
2325-
expected = numpy.sign(np_a)
2326-
assert_dtype_allclose(result, expected)
2327-
2328-
# out keyword
2329-
if dtype is not None:
2330-
dp_out = dpnp.empty(expected.shape, dtype=expected.dtype)
2331-
result = dpnp.sign(dpnp_a, out=dp_out)
2332-
assert dp_out is result
2320+
if dtype == dpnp.bool:
2321+
assert_raises(TypeError, dpnp.sign, ia)
2322+
else:
2323+
result = dpnp.sign(ia)
2324+
expected = numpy.sign(a)
23332325
assert_dtype_allclose(result, expected)
23342326

2335-
2336-
def test_sign_boolean():
2337-
dpnp_a = dpnp.array([True, False])
2338-
2339-
with pytest.raises(TypeError):
2340-
dpnp.sign(dpnp_a)
2327+
# out keyword
2328+
iout = dpnp.empty(expected.shape, dtype=expected.dtype)
2329+
result = dpnp.sign(ia, out=iout)
2330+
assert iout is result
2331+
assert_dtype_allclose(result, expected)
23412332

23422333

23432334
@pytest.mark.parametrize(

dpnp/tests/test_strides.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_complex_dtypes,
1414
get_float_complex_dtypes,
1515
get_integer_dtypes,
16+
numpy_version,
1617
)
1718

1819

@@ -67,8 +68,10 @@ def test_1arg_support_complex(func, dtype, stride):
6768
x = generate_random_numpy_array(10, dtype=dtype)
6869
a, ia = x[::stride], dpnp.array(x)[::stride]
6970

70-
# dpnp default is stable=True
71-
kwargs = {"stable": True} if func == "argsort" else {}
71+
if numpy_version() < "2.0.0" and func in ["sign"]:
72+
pytest.skip("numpy definition is different for complex numbers.")
73+
# dpnp default is stable
74+
kwargs = {"kind": "stable"} if func == "argsort" else {}
7275
result = getattr(dpnp, func)(ia)
7376
expected = getattr(numpy, func)(a, **kwargs)
7477
assert_dtype_allclose(result, expected, factor=24)

0 commit comments

Comments
 (0)