|
38 | 38 | """ |
39 | 39 |
|
40 | 40 |
|
| 41 | +import functools |
41 | 42 | import math |
42 | 43 | import operator |
43 | 44 | import warnings |
@@ -101,6 +102,7 @@ class UniqueInverseResult(NamedTuple): |
101 | 102 | "broadcast_shapes", |
102 | 103 | "broadcast_to", |
103 | 104 | "can_cast", |
| 105 | + "common_type", |
104 | 106 | "column_stack", |
105 | 107 | "concat", |
106 | 108 | "concatenate", |
@@ -1310,6 +1312,62 @@ def can_cast(from_, to, casting="safe"): |
1310 | 1312 | return dpt.can_cast(dtype_from, to, casting=casting) |
1311 | 1313 |
|
1312 | 1314 |
|
| 1315 | +def common_type(*arrays): |
| 1316 | + """ |
| 1317 | + Return a scalar type which is common to the input arrays. |
| 1318 | +
|
| 1319 | + The return type will always be an inexact (i.e. floating point or complex) |
| 1320 | + scalar type, even if all the arrays are integer arrays. |
| 1321 | + If one of the inputs is an integer array, the minimum precision type |
| 1322 | + that is returned is determined by the device capabilities. |
| 1323 | +
|
| 1324 | + For full documentation refer to :obj:`numpy.common_type`. |
| 1325 | +
|
| 1326 | + Parameters |
| 1327 | + ---------- |
| 1328 | + arrays: {dpnp.ndarray, usm_ndarray} |
| 1329 | + Input arrays. |
| 1330 | +
|
| 1331 | + Returns |
| 1332 | + ------- |
| 1333 | + out: data type |
| 1334 | + Data type object. |
| 1335 | +
|
| 1336 | + Examples |
| 1337 | + -------- |
| 1338 | + >>> import dpnp as np |
| 1339 | + >>> np.common_type(np.arange(2, dtype=np.float32)) |
| 1340 | + numpy.float32 |
| 1341 | + >>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2)) |
| 1342 | + numpy.float64 |
| 1343 | + >>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0])) |
| 1344 | + numpy.complex128 |
| 1345 | +
|
| 1346 | + """ |
| 1347 | + |
| 1348 | + if len(arrays) == 0: |
| 1349 | + return ( |
| 1350 | + dpnp.float16 |
| 1351 | + if dpctl.select_default_device().has_aspect_fp16 |
| 1352 | + else dpnp.float32 |
| 1353 | + ) |
| 1354 | + |
| 1355 | + dpnp.check_supported_arrays_type(*arrays) |
| 1356 | + |
| 1357 | + _, exec_q = get_usm_allocations(arrays) |
| 1358 | + default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q) |
| 1359 | + dtypes = [] |
| 1360 | + for a in arrays: |
| 1361 | + if a.dtype.kind == "b": |
| 1362 | + raise TypeError("can't get common type for non-numeric array") |
| 1363 | + if a.dtype.kind in "iu": |
| 1364 | + dtypes.append(default_float_dtype) |
| 1365 | + else: |
| 1366 | + dtypes.append(a.dtype) |
| 1367 | + |
| 1368 | + return functools.reduce(numpy.promote_types, dtypes).type |
| 1369 | + |
| 1370 | + |
1313 | 1371 | def column_stack(tup): |
1314 | 1372 | """ |
1315 | 1373 | Stacks 1-D and 2-D arrays as columns into a 2-D array. |
|
0 commit comments