Skip to content

Commit 5e1a355

Browse files
author
Vahid Tavanashad
committed
add new tests
1 parent fec42b1 commit 5e1a355

File tree

2 files changed

+55
-29
lines changed

2 files changed

+55
-29
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def __call__(self, x, deg=False, out=None, order="K"):
513513

514514

515515
class DPNPFix(DPNPUnaryFunc):
516-
"""Class that implements dpnp.real unary element-wise functions."""
516+
"""Class that implements dpnp.fix unary element-wise functions."""
517517

518518
def __init__(
519519
self,
@@ -534,14 +534,16 @@ def __call__(self, x, out=None, order="K"):
534534
pass # pass to raise error in main implementation
535535
elif dpnp.issubdtype(x.dtype, dpnp.inexact):
536536
pass # for inexact types, pass to calculate in the backend
537-
elif out is not None and (
538-
not dpnp.is_supported_array_type(out) or out.dtype != x.dtype
539-
):
537+
elif out is not None and not dpnp.is_supported_array_type(out):
540538
pass # pass to raise error in main implementation
539+
elif out is not None and out.dtype != x.dtype:
540+
raise ValueError(
541+
f"Output array of type {x.dtype} is needed, got {out.dtype}"
542+
)
541543
else:
542544
# for exact types, return the input
543545
if out is None:
544-
return dpnp.asarray(x, copy=True)
546+
return dpnp.copy(x, order=order)
545547

546548
if isinstance(out, dpt.usm_ndarray):
547549
out = dpnp_array._create_from_usm_ndarray(out)

dpnp/tests/test_mathematical.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2227,19 +2227,37 @@ def test_complex(self, func, xp, dt):
22272227

22282228
@testing.with_requires("numpy>=2.1.0")
22292229
@pytest.mark.parametrize(
2230-
"dt", get_all_dtypes(no_none=True, no_complex=True)
2230+
"dt_in", get_all_dtypes(no_none=True, no_complex=True)
22312231
)
2232-
def test_out(self, func, dt):
2233-
a = generate_random_numpy_array(10, dt)
2234-
# TODO: use dt_out = dt when dpctl#2030 is fixed
2235-
dt_out = numpy.int8 if dt == dpnp.bool else dt
2232+
@pytest.mark.parametrize(
2233+
"dt_out", get_all_dtypes(no_none=True, no_complex=True)
2234+
)
2235+
def test_out(self, func, dt_in, dt_out):
2236+
a = generate_random_numpy_array(10, dt_in)
22362237
out = numpy.empty(a.shape, dtype=dt_out)
22372238
ia, iout = dpnp.array(a), dpnp.array(out)
22382239

2239-
expected = getattr(numpy, func)(a, out=out)
2240-
result = getattr(dpnp, func)(ia, out=iout)
2241-
assert result is iout
2242-
assert_array_equal(result, expected)
2240+
if dt_in != dt_out:
2241+
if numpy.can_cast(dt_in, dt_out, casting="same_kind"):
2242+
# NumPy allows "same_kind" casting, dpnp does not
2243+
if func != "fix" and dt_in == dpnp.bool and dt_out == dpnp.int8:
2244+
# TODO: get rid of w/a when dpctl#2030 is fixed
2245+
pass
2246+
else:
2247+
assert_raises(ValueError, getattr(dpnp, func), ia, out=iout)
2248+
else:
2249+
assert_raises(ValueError, getattr(dpnp, func), ia, out=iout)
2250+
assert_raises(TypeError, getattr(numpy, func), a, out=out)
2251+
else:
2252+
if func != "fix" and dt_in == dpnp.bool:
2253+
# TODO: get rid of w/a when dpctl#2030 is fixed
2254+
out = out.astype(numpy.int8)
2255+
iout = iout.astype(dpnp.int8)
2256+
2257+
expected = getattr(numpy, func)(a, out=out)
2258+
result = getattr(dpnp, func)(ia, out=iout)
2259+
assert result is iout
2260+
assert_array_equal(result, expected)
22432261

22442262
@pytest.mark.skipif(not has_support_aspect16(), reason="no fp16 support")
22452263
def test_out_float16(self, func):
@@ -2252,22 +2270,22 @@ def test_out_float16(self, func):
22522270
assert result is iout
22532271
assert_array_equal(result, expected)
22542272

2255-
@pytest.mark.parametrize("xp", [numpy, dpnp])
22562273
@pytest.mark.parametrize(
2257-
"dt_out", get_all_dtypes(no_none=True, no_complex=True)[:-1]
2274+
"dt", get_all_dtypes(no_none=True, no_complex=True)
22582275
)
2259-
def test_invalid_dtype(self, func, xp, dt_out):
2260-
dt_in = get_all_dtypes(no_none=True, no_complex=True)[-1]
2261-
a = xp.arange(10, dtype=dt_in)
2262-
out = xp.empty(10, dtype=dt_out)
2263-
if dt_out == numpy.float32 and dt_in == numpy.float64:
2264-
if xp == dpnp:
2265-
# NumPy allows "same_kind" casting, dpnp does not
2266-
assert_raises(ValueError, getattr(dpnp, func), a, out=out)
2267-
else:
2268-
assert_raises(
2269-
(ValueError, TypeError), getattr(xp, func), a, out=out
2270-
)
2276+
def test_out_usm_ndarray(self, func, dt):
2277+
a = generate_random_numpy_array(10, dt)
2278+
out = numpy.empty(a.shape, dtype=dt)
2279+
ia, usm_out = dpnp.array(a), dpt.asarray(out)
2280+
2281+
if func != "fix" and dt == dpnp.bool:
2282+
# TODO: get rid of w/a when dpctl#2030 is fixed
2283+
out = out.astype(numpy.int8)
2284+
usm_out = dpt.asarray(usm_out, dtype=dpnp.int8)
2285+
2286+
expected = getattr(numpy, func)(a, out=out)
2287+
result = getattr(dpnp, func)(ia, out=usm_out)
2288+
assert_array_equal(result, expected)
22712289

22722290
@pytest.mark.parametrize("xp", [numpy, dpnp])
22732291
@pytest.mark.parametrize(
@@ -2278,9 +2296,15 @@ def test_invalid_shape(self, func, xp, shape):
22782296
out = xp.empty(shape, dtype=xp.float32)
22792297
assert_raises(ValueError, getattr(xp, func), a, out=out)
22802298

2281-
def test_scalar(self, func):
2299+
def test_error(self, func):
2300+
# scalar, unsupported input
22822301
assert_raises(TypeError, getattr(dpnp, func), -3.4)
22832302

2303+
# unsupported out
2304+
a = dpnp.array([1, 2, 3])
2305+
out = numpy.empty_like(3, dtype=a.dtype)
2306+
assert_raises(TypeError, getattr(dpnp, func), a, out=out)
2307+
22842308

22852309
class TestHypot:
22862310
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)