Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"""


import functools
import math
import operator
import warnings
Expand Down Expand Up @@ -101,6 +102,7 @@ class UniqueInverseResult(NamedTuple):
"broadcast_shapes",
"broadcast_to",
"can_cast",
"common_type",
"column_stack",
"concat",
"concatenate",
Expand Down Expand Up @@ -1310,6 +1312,62 @@ def can_cast(from_, to, casting="safe"):
return dpt.can_cast(dtype_from, to, casting=casting)


def common_type(*arrays):
"""
Return a scalar type which is common to the input arrays.

The return type will always be an inexact (i.e. floating point or complex)
scalar type, even if all the arrays are integer arrays.
If one of the inputs is an integer array, the minimum precision type
that is returned is determined by the device capabilities.

For full documentation refer to :obj:`numpy.common_type`.

Parameters
----------
arrays: {dpnp.ndarray, usm_ndarray}
Input arrays.

Returns
-------
out: data type
Data type object.

Examples
--------
>>> import dpnp as np
>>> np.common_type(np.arange(2, dtype=np.float32))
numpy.float32
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
numpy.float64
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
numpy.complex128

"""

if len(arrays) == 0:
return (
dpnp.float16
if dpctl.select_default_device().has_aspect_fp16
else dpnp.float32
)

dpnp.check_supported_arrays_type(*arrays)

_, exec_q = get_usm_allocations(arrays)
default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q)
dtypes = []
for a in arrays:
if a.dtype.kind == "b":
raise TypeError("can't get common type for non-numeric array")
if a.dtype.kind in "iu":
dtypes.append(default_float_dtype)
else:
dtypes.append(a.dtype)

return functools.reduce(numpy.promote_types, dtypes).type


def column_stack(tup):
"""
Stacks 1-D and 2-D arrays as columns into a 2-D array.
Expand Down
15 changes: 13 additions & 2 deletions dpnp/tests/third_party/cupy/test_type_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import dpnp as cupy
from dpnp.tests.helper import has_support_aspect64
from dpnp.tests.helper import has_support_aspect16, has_support_aspect64
from dpnp.tests.third_party.cupy import testing


Expand Down Expand Up @@ -47,13 +47,17 @@ def test_can_cast(self, xp, from_dtype, to_dtype):
return ret


@pytest.mark.skip("dpnp.common_type() is not implemented yet")
class TestCommonType(unittest.TestCase):

@testing.numpy_cupy_equal()
def test_common_type_empty(self, xp):
ret = xp.common_type()
assert type(ret) is type
# NumPy always returns float16 for empty input,
# but dpnp returns float32 if the device does not support
# 16-bit precision floating point operations
if xp is numpy and not has_support_aspect16():
return xp.float32
return ret

@testing.for_all_dtypes(no_bool=True)
Expand All @@ -62,6 +66,11 @@ def test_common_type_single_argument(self, xp, dtype):
array = _generate_type_routines_input(xp, dtype, "array")
ret = xp.common_type(array)
assert type(ret) is type
# NumPy promotes integer types to float64,
# but dpnp may return float32 if the device does not support
# 64-bit precision floating point operations.
if xp is numpy and not has_support_aspect64():
return xp.float32
return ret

@testing.for_all_dtypes_combination(
Expand All @@ -73,6 +82,8 @@ def test_common_type_two_arguments(self, xp, dtype1, dtype2):
array2 = _generate_type_routines_input(xp, dtype2, "array")
ret = xp.common_type(array1, array2)
assert type(ret) is type
if xp is numpy and not has_support_aspect64():
return xp.float32
return ret

@testing.for_all_dtypes()
Expand Down
Loading