Skip to content

Commit a8fd61c

Browse files
author
Vahid Tavanashad
committed
update dpnp.fix docstring
1 parent 008dc96 commit a8fd61c

File tree

3 files changed

+17
-29
lines changed

3 files changed

+17
-29
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,8 +1750,8 @@ def ediff1d(ary, to_end=None, to_begin=None):
17501750
-------
17511751
out : dpnp.ndarray
17521752
An array with the rounded values and with the same dimensions as the input.
1753-
The returned array will have the default floating point data type for the
1754-
device where `a` is allocated.
1753+
The returned array will have a floating point data type that input can cast
1754+
to it safely considering device capabilities.
17551755
If `out` is ``None`` then a float array is returned with the rounded values.
17561756
Otherwise the result is stored there and the return value `out` is
17571757
a reference to that array.

dpnp/dpnp_iface_statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,8 +797,8 @@ def cov(
797797
Default: ``None``.
798798
dtype : {None, str, dtype object}, optional
799799
Data-type of the result. By default, the return data-type will have
800-
at least floating point type based on the capabilities of the device on
801-
which the input arrays reside.
800+
the default floating point data-type of the device on which the input
801+
arrays reside.
802802
803803
Default: ``None``.
804804

dpnp/tests/test_mathematical.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,6 @@
3636
from .third_party.cupy import testing
3737

3838

39-
def _get_output_data_type(dtype):
40-
"""Return a data type specified by input `dtype` and device capabilities."""
41-
dtype_float16 = any(
42-
dpnp.issubdtype(dtype, t) for t in (dpnp.bool, dpnp.int8, dpnp.uint8)
43-
)
44-
dtype_float32 = any(
45-
dpnp.issubdtype(dtype, t) for t in (dpnp.int16, dpnp.uint16)
46-
)
47-
if dtype_float16:
48-
out_dtype = dpnp.float16 if has_support_aspect16() else dpnp.float32
49-
elif dtype_float32:
50-
out_dtype = dpnp.float32
51-
elif dpnp.issubdtype(dtype, dpnp.complexfloating):
52-
out_dtype = dpnp.complex64
53-
if has_support_aspect64() and dtype != dpnp.complex64:
54-
out_dtype = dpnp.complex128
55-
else:
56-
out_dtype = dpnp.float32
57-
if has_support_aspect64() and dtype != dpnp.float32:
58-
out_dtype = dpnp.float64
59-
60-
return out_dtype
61-
62-
6339
@pytest.mark.parametrize("deg", [True, False])
6440
class TestAngle:
6541
def test_angle_bool(self, deg):
@@ -775,6 +751,16 @@ def test_errors(self):
775751

776752

777753
class TestFix:
754+
def get_output_data_type(self, dtype, device):
755+
if dpnp.can_cast(dtype, dpnp.float16):
756+
return dpnp.float16
757+
if dpnp.can_cast(dtype, dpnp.float32):
758+
return dpnp.float32
759+
if dpnp.can_cast(dtype, dpnp.float64):
760+
return dpnp.float64
761+
762+
return dpnp.default_float_type(device)
763+
778764
@pytest.mark.parametrize(
779765
"dt", get_all_dtypes(no_none=True, no_complex=True)
780766
)
@@ -802,7 +788,9 @@ def test_out(self, a_dt):
802788
)
803789
ia = dpnp.array(a)
804790

805-
out_dt = _get_output_data_type(a.dtype)
791+
# numpy output has the same dtype as input
792+
# dpnp output always has a floating point dtype
793+
out_dt = self.get_output_data_type(a.dtype, a.device)
806794
out = numpy.zeros_like(a, dtype=out_dt)
807795
iout = dpnp.array(out)
808796

0 commit comments

Comments
 (0)