Skip to content

Commit 0ffa47e

Browse files
Implement dpnp.common_type()
1 parent e0b7932 commit 0ffa47e

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"""
3939

4040

41+
import functools
4142
import math
4243
import operator
4344
import warnings
@@ -101,6 +102,7 @@ class UniqueInverseResult(NamedTuple):
101102
"broadcast_shapes",
102103
"broadcast_to",
103104
"can_cast",
105+
"common_type",
104106
"column_stack",
105107
"concat",
106108
"concatenate",
@@ -1310,6 +1312,62 @@ def can_cast(from_, to, casting="safe"):
13101312
return dpt.can_cast(dtype_from, to, casting=casting)
13111313

13121314

1315+
def common_type(*arrays):
1316+
"""
1317+
Return a scalar type which is common to the input arrays.
1318+
1319+
The return type will always be an inexact (i.e. floating point or complex)
1320+
scalar type, even if all the arrays are integer arrays.
1321+
If one of the inputs is an integer array, the minimum precision type
1322+
that is returned is determined by the device capabilities.
1323+
1324+
For full documentation refer to :obj:`numpy.common_type`.
1325+
1326+
Parameters
1327+
----------
1328+
arrays: {dpnp.ndarray, usm_ndarray}
1329+
Input arrays.
1330+
1331+
Returns
1332+
-------
1333+
out: data type
1334+
Data type object.
1335+
1336+
Examples
1337+
--------
1338+
>>> import dpnp as np
1339+
>>> np.common_type(np.arange(2, dtype=np.float32))
1340+
numpy.float32
1341+
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
1342+
numpy.float64
1343+
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
1344+
numpy.complex128
1345+
1346+
"""
1347+
1348+
if len(arrays) == 0:
1349+
return (
1350+
dpnp.float16
1351+
if dpctl.select_default_device().has_aspect_fp16
1352+
else dpnp.float32
1353+
)
1354+
1355+
dpnp.check_supported_arrays_type(*arrays)
1356+
1357+
_, exec_q = get_usm_allocations(arrays)
1358+
default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q)
1359+
dtypes = []
1360+
for a in arrays:
1361+
if a.dtype.kind == "b":
1362+
raise TypeError("can't get common type for non-numeric array")
1363+
if a.dtype.kind in "iu":
1364+
dtypes.append(default_float_dtype)
1365+
else:
1366+
dtypes.append(a.dtype)
1367+
1368+
return functools.reduce(numpy.promote_types, dtypes).type
1369+
1370+
13131371
def column_stack(tup):
13141372
"""
13151373
Stacks 1-D and 2-D arrays as columns into a 2-D array.

0 commit comments

Comments
 (0)