Skip to content

Commit 5b8cdd7

Browse files
author
Vahid Tavanashad
committed
updates for obtaining out dtype
1 parent 261f933 commit 5b8cdd7

File tree

3 files changed

+36
-42
lines changed

3 files changed

+36
-42
lines changed

dpnp/tests/test_binary_ufuncs.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
has_support_aspect16,
2424
numpy_version,
2525
)
26-
from .test_umath import _get_output_data_type
2726

2827
"""
2928
The scope includes tests with only functions which are instances of
@@ -195,10 +194,7 @@ def test_divide(self, dtype):
195194
expected = numpy.divide(a, b)
196195

197196
ia, ib = dpnp.array(a), dpnp.array(b)
198-
if numpy.issubdtype(dtype, numpy.integer) or dtype == dpnp.bool:
199-
out_dtype = map_dtype_to_device(dpnp.float64, ia.sycl_device)
200-
else:
201-
out_dtype = _get_output_data_type(dtype)
197+
out_dtype = map_dtype_to_device(expected.dtype, ia.sycl_device)
202198
iout = dpnp.empty(expected.shape, dtype=out_dtype)
203199
result = dpnp.divide(ia, ib, out=iout)
204200

dpnp/tests/test_mathematical.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import dpnp
1919
from dpnp.dpnp_array import dpnp_array
20+
from dpnp.dpnp_utils import map_dtype_to_device
2021

2122
from .helper import (
2223
assert_dtype_allclose,
@@ -32,10 +33,33 @@
3233
has_support_aspect64,
3334
numpy_version,
3435
)
35-
from .test_umath import _get_output_data_type
3636
from .third_party.cupy import testing
3737

3838

39+
def _get_output_data_type(dtype):
40+
"""Return a data type specified by input `dtype` and device capabilities."""
41+
dtype_float16 = any(
42+
dpnp.issubdtype(dtype, t) for t in (dpnp.bool, dpnp.int8, dpnp.uint8)
43+
)
44+
dtype_float32 = any(
45+
dpnp.issubdtype(dtype, t) for t in (dpnp.int16, dpnp.uint16)
46+
)
47+
if dtype_float16:
48+
out_dtype = dpnp.float16 if has_support_aspect16() else dpnp.float32
49+
elif dtype_float32:
50+
out_dtype = dpnp.float32
51+
elif dpnp.issubdtype(dtype, dpnp.complexfloating):
52+
out_dtype = dpnp.complex64
53+
if has_support_aspect64() and dtype != dpnp.complex64:
54+
out_dtype = dpnp.complex128
55+
else:
56+
out_dtype = dpnp.float32
57+
if has_support_aspect64() and dtype != dpnp.float32:
58+
out_dtype = dpnp.float64
59+
60+
return out_dtype
61+
62+
3963
@pytest.mark.parametrize("deg", [True, False])
4064
class TestAngle:
4165
def test_angle_bool(self, deg):
@@ -2323,7 +2347,7 @@ def test_hypot(self, dtype):
23232347
expected = numpy.hypot(a, b)
23242348

23252349
ia, ib = dpnp.array(a), dpnp.array(b)
2326-
out_dt = _get_output_data_type(dtype)
2350+
out_dt = map_dtype_to_device(expected.dtype, ia.sycl_device)
23272351
iout = dpnp.empty(expected.shape, dtype=out_dt)
23282352
result = dpnp.hypot(ia, ib, out=iout)
23292353

dpnp/tests/test_umath.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010

1111
import dpnp
12+
from dpnp.dpnp_utils import map_dtype_to_device
1213

1314
from .helper import (
1415
assert_dtype_allclose,
@@ -108,30 +109,6 @@ def test_umaths(test_cases):
108109
assert_allclose(result, expected, rtol=1e-6)
109110

110111

111-
def _get_output_data_type(dtype):
112-
"""Return a data type specified by input `dtype` and device capabilities."""
113-
dtype_float16 = any(
114-
dpnp.issubdtype(dtype, t) for t in (dpnp.bool, dpnp.int8, dpnp.uint8)
115-
)
116-
dtype_float32 = any(
117-
dpnp.issubdtype(dtype, t) for t in (dpnp.int16, dpnp.uint16)
118-
)
119-
if dtype_float16:
120-
dt_out = dpnp.float16 if has_support_aspect16() else dpnp.float32
121-
elif dtype_float32:
122-
dt_out = dpnp.float32
123-
elif dpnp.issubdtype(dtype, dpnp.complexfloating):
124-
dt_out = dpnp.complex64
125-
if has_support_aspect64() and dtype != dpnp.complex64:
126-
dt_out = dpnp.complex128
127-
else:
128-
dt_out = dpnp.float32
129-
if has_support_aspect64() and dtype != dpnp.float32:
130-
dt_out = dpnp.float64
131-
132-
return dt_out
133-
134-
135112
class TestArctan2:
136113
@pytest.mark.parametrize(
137114
"dtype", get_all_dtypes(no_none=True, no_complex=True)
@@ -142,10 +119,10 @@ def test_arctan2(self, dtype):
142119
expected = numpy.arctan2(a, b)
143120

144121
ia, ib = dpnp.array(a), dpnp.array(b)
145-
dt_out = _get_output_data_type(dtype)
122+
dt_out = map_dtype_to_device(expected.dtype, ia.sycl_device)
146123
iout = dpnp.empty(expected.shape, dtype=dt_out)
147-
result = dpnp.arctan2(ia, ib, out=iout)
148124

125+
result = dpnp.arctan2(ia, ib, out=iout)
149126
assert result is iout
150127
assert_dtype_allclose(result, expected)
151128

@@ -188,7 +165,7 @@ def test_copysign(self, dtype):
188165
expected = numpy.copysign(a, b)
189166

190167
ia, ib = dpnp.array(a), dpnp.array(b)
191-
dt_out = _get_output_data_type(dtype)
168+
dt_out = map_dtype_to_device(expected.dtype, ia.sycl_device)
192169
iout = dpnp.empty(expected.shape, dtype=dt_out)
193170
result = dpnp.copysign(ia, ib, out=iout)
194171

@@ -307,7 +284,7 @@ def test_logaddexp(self, dtype):
307284
expected = numpy.logaddexp(a, b)
308285

309286
ia, ib = dpnp.array(a), dpnp.array(b)
310-
dt_out = _get_output_data_type(dtype)
287+
dt_out = map_dtype_to_device(expected.dtype, ia.sycl_device)
311288
iout = dpnp.empty(expected.shape, dtype=dt_out)
312289
result = dpnp.logaddexp(ia, ib, out=iout)
313290

@@ -450,7 +427,7 @@ def test_reciprocal(self, dtype):
450427
expected = numpy.reciprocal(a)
451428

452429
ia = dpnp.array(a)
453-
dt_out = _get_output_data_type(dtype)
430+
dt_out = map_dtype_to_device(expected.dtype, ia.sycl_device)
454431
iout = dpnp.empty(expected.shape, dtype=dt_out)
455432
result = dpnp.reciprocal(ia, out=iout)
456433

@@ -500,7 +477,7 @@ def test_basic(self, func_params, dtype):
500477
expected = getattr(numpy, func)(a)
501478

502479
ia = dpnp.array(a)
503-
dt_out = _get_output_data_type(dtype)
480+
dt_out = map_dtype_to_device(expected.dtype, ia.sycl_device)
504481
iout = dpnp.empty(expected.shape, dtype=dt_out)
505482
result = getattr(dpnp, func)(ia, out=iout)
506483
assert result is iout
@@ -591,7 +568,7 @@ def func_params(self, request):
591568
@pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning")
592569
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
593570
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
594-
def test_out(self, func_params, dtype):
571+
def test_basic(self, func_params, dtype):
595572
func = func_params["func"]
596573
values = func_params["values"]
597574
a = generate_random_numpy_array(
@@ -600,10 +577,7 @@ def test_out(self, func_params, dtype):
600577
expected = getattr(numpy, func)(a)
601578

602579
ia = dpnp.array(a)
603-
if func == "square":
604-
dt_out = numpy.int8 if dtype == dpnp.bool else dtype
605-
else:
606-
dt_out = _get_output_data_type(dtype)
580+
dt_out = map_dtype_to_device(expected.dtype, ia.sycl_device)
607581
iout = dpnp.empty(expected.shape, dtype=dt_out)
608582
result = getattr(dpnp, func)(ia, out=iout)
609583

0 commit comments

Comments
 (0)