diff --git a/dpnp/dpnp_iface_types.py b/dpnp/dpnp_iface_types.py index 20ffa55f3c4b..6834ad466276 100644 --- a/dpnp/dpnp_iface_types.py +++ b/dpnp/dpnp_iface_types.py @@ -32,16 +32,25 @@ This module provides public type interface file for the library """ +import functools + +import dpctl import dpctl.tensor as dpt import numpy +import dpnp + from .dpnp_array import dpnp_array +# pylint: disable=no-name-in-module +from .dpnp_utils import get_usm_allocations + __all__ = [ "bool", "bool_", "byte", "cdouble", + "common_type", "complex128", "complex64", "complexfloating", @@ -145,6 +154,67 @@ pi = numpy.pi +def common_type(*arrays): + """ + Return a scalar type which is common to the input arrays. + + The return type will always be an inexact (i.e. floating point or complex) + scalar type, even if all the arrays are integer arrays. + If one of the inputs is an integer array, the minimum precision type + that is returned is the default floating point data type for the device + where the input arrays are allocated. + + For full documentation refer to :obj:`numpy.common_type`. + + Parameters + ---------- + arrays: {dpnp.ndarray, usm_ndarray} + Input arrays. + + Returns + ------- + out: data type + Data type object. + + See Also + -------- + :obj:`dpnp.dtype` : Create a data type object. + + Examples + -------- + >>> import dpnp as np + >>> np.common_type(np.arange(2, dtype=np.float32)) + numpy.float32 + >>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2)) + numpy.float64 # may vary + >>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0])) + numpy.complex128 # may vary + + """ + + if len(arrays) == 0: + return ( + dpnp.float16 + if dpctl.select_default_device().has_aspect_fp16 + else dpnp.float32 + ) + + dpnp.check_supported_arrays_type(*arrays) + + _, exec_q = get_usm_allocations(arrays) + default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q) + dtypes = [] + for a in arrays: + if not dpnp.issubdtype(a.dtype, dpnp.number): + raise TypeError("can't get common type for non-numeric array") + if dpnp.issubdtype(a.dtype, dpnp.integer): + dtypes.append(default_float_dtype) + else: + dtypes.append(a.dtype) + + return functools.reduce(numpy.promote_types, dtypes).type + + # pylint: disable=redefined-outer-name def finfo(dtype): """ diff --git a/dpnp/tests/third_party/cupy/test_type_routines.py b/dpnp/tests/third_party/cupy/test_type_routines.py index 9e59baa7971d..e35b40d90841 100644 --- a/dpnp/tests/third_party/cupy/test_type_routines.py +++ b/dpnp/tests/third_party/cupy/test_type_routines.py @@ -4,7 +4,7 @@ import pytest import dpnp as cupy -from dpnp.tests.helper import has_support_aspect64 +from dpnp.tests.helper import has_support_aspect16, has_support_aspect64 from dpnp.tests.third_party.cupy import testing @@ -47,13 +47,17 @@ def test_can_cast(self, xp, from_dtype, to_dtype): return ret -@pytest.mark.skip("dpnp.common_type() is not implemented yet") class TestCommonType(unittest.TestCase): @testing.numpy_cupy_equal() def test_common_type_empty(self, xp): ret = xp.common_type() assert type(ret) is type + # NumPy always returns float16 for empty input, + # but dpnp returns float32 if the device does not support + # 16-bit precision floating point operations + if xp is numpy and not has_support_aspect16(): + return xp.float32 return ret @testing.for_all_dtypes(no_bool=True) @@ -62,6 +66,11 @@ def test_common_type_single_argument(self, xp, dtype): array = _generate_type_routines_input(xp, dtype, "array") ret = xp.common_type(array) assert type(ret) is type + # NumPy promotes integer types to float64, + # but dpnp may return float32 if the device does not support + # 64-bit precision floating point operations. + if xp is numpy and not has_support_aspect64(): + return xp.float32 return ret @testing.for_all_dtypes_combination( @@ -73,6 +82,8 @@ def test_common_type_two_arguments(self, xp, dtype1, dtype2): array2 = _generate_type_routines_input(xp, dtype2, "array") ret = xp.common_type(array1, array2) assert type(ret) is type + if xp is numpy and not has_support_aspect64(): + return xp.float32 return ret @testing.for_all_dtypes()