Skip to content

Commit d60f5b6

Browse files
committed
Add dpnp.common_type implementation
1 parent 25d0ddd commit d60f5b6

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

dpnp/dpnp_iface.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"as_usm_ndarray",
6262
"check_limitations",
6363
"check_supported_arrays_type",
64+
"common_type",
6465
"default_float_type",
6566
"from_dlpack",
6667
"get_dpnp_descriptor",
@@ -406,6 +407,73 @@ def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False):
406407
return True
407408

408409

410+
# determine the "minimum common type" for a group of arrays
411+
array_precision = {
412+
dpnp.float16: 0,
413+
dpnp.float32: 1,
414+
dpnp.float64: 2,
415+
dpnp.complex64: 3,
416+
dpnp.complex128: 4,
417+
}
418+
419+
array_type = {
420+
"float": {0: dpnp.float16, 1: dpnp.float32, 2: dpnp.float64},
421+
"complex": {3: dpnp.complex64, 4: dpnp.complex128},
422+
}
423+
424+
425+
def common_type(*arrays):
426+
"""
427+
Return a scalar type which is common to the input arrays.
428+
429+
The return type will always be an inexact (i.e. floating point) scalar
430+
type, even if all the arrays are integer arrays. If one of the inputs is
431+
an integer array, the minimum precision type that is returned is a
432+
64-bit floating point dtype.
433+
434+
For full documentation refer to :obj:`numpy.common_type`
435+
436+
Parameters
437+
----------
438+
array1, array2, ... : {dpnp.ndarray, usm_ndarray}
439+
Input arrays.
440+
441+
Returns
442+
-------
443+
out : data type code
444+
Data type code.
445+
446+
Examples
447+
--------
448+
>>> import dpnp as np
449+
>>> np.common_type(np.arange(2, dtype=np.float32))
450+
<class 'numpy.float32'>
451+
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
452+
<class 'numpy.float64'>
453+
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
454+
<class 'numpy.complex128'>
455+
456+
"""
457+
dpnp.check_supported_arrays_type(*arrays)
458+
459+
is_complex = False
460+
max_precision = 0
461+
462+
for a in arrays:
463+
t = a.dtype.type
464+
465+
if dpnp.issubdtype(t, dpnp.complexfloating):
466+
is_complex = True
467+
if dpnp.issubdtype(t, dpnp.integer):
468+
t = dpnp.float64
469+
470+
max_precision = max(max_precision, array_precision.get(t, 0))
471+
472+
if is_complex:
473+
return array_type["complex"].get(max_precision, dpnp.complex128)
474+
return array_type["float"].get(max_precision, dpnp.float64)
475+
476+
409477
def default_float_type(device=None, sycl_queue=None):
410478
"""
411479
Return a floating type used by default in DPNP depending on device

tests/third_party/cupy/test_type_routines.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def test_can_cast(self, xp, from_dtype, to_dtype):
4646
return ret
4747

4848

49-
@pytest.mark.skip("dpnp.common_type() is not implemented yet")
5049
class TestCommonType(unittest.TestCase):
5150
@testing.numpy_cupy_equal()
5251
def test_common_type_empty(self, xp):

0 commit comments

Comments
 (0)