|
61 | 61 | "as_usm_ndarray", |
62 | 62 | "check_limitations", |
63 | 63 | "check_supported_arrays_type", |
| 64 | + "common_type", |
64 | 65 | "default_float_type", |
65 | 66 | "from_dlpack", |
66 | 67 | "get_dpnp_descriptor", |
@@ -406,6 +407,73 @@ def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False): |
406 | 407 | return True |
407 | 408 |
|
408 | 409 |
|
| 410 | +# determine the "minimum common type" for a group of arrays |
| 411 | +array_precision = { |
| 412 | + dpnp.float16: 0, |
| 413 | + dpnp.float32: 1, |
| 414 | + dpnp.float64: 2, |
| 415 | + dpnp.complex64: 3, |
| 416 | + dpnp.complex128: 4, |
| 417 | +} |
| 418 | + |
| 419 | +array_type = { |
| 420 | + "float": {0: dpnp.float16, 1: dpnp.float32, 2: dpnp.float64}, |
| 421 | + "complex": {3: dpnp.complex64, 4: dpnp.complex128}, |
| 422 | +} |
| 423 | + |
| 424 | + |
| 425 | +def common_type(*arrays): |
| 426 | + """ |
| 427 | + Return a scalar type which is common to the input arrays. |
| 428 | +
|
| 429 | + The return type will always be an inexact (i.e. floating point) scalar |
| 430 | + type, even if all the arrays are integer arrays. If one of the inputs is |
| 431 | + an integer array, the minimum precision type that is returned is a |
| 432 | + 64-bit floating point dtype. |
| 433 | +
|
| 434 | + For full documentation refer to :obj:`numpy.common_type` |
| 435 | +
|
| 436 | + Parameters |
| 437 | + ---------- |
| 438 | + array1, array2, ... : {dpnp.ndarray, usm_ndarray} |
| 439 | + Input arrays. |
| 440 | +
|
| 441 | + Returns |
| 442 | + ------- |
| 443 | + out : data type code |
| 444 | + Data type code. |
| 445 | +
|
| 446 | + Examples |
| 447 | + -------- |
| 448 | + >>> import dpnp as np |
| 449 | + >>> np.common_type(np.arange(2, dtype=np.float32)) |
| 450 | + <class 'numpy.float32'> |
| 451 | + >>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2)) |
| 452 | + <class 'numpy.float64'> |
| 453 | + >>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0])) |
| 454 | + <class 'numpy.complex128'> |
| 455 | +
|
| 456 | + """ |
| 457 | + dpnp.check_supported_arrays_type(*arrays) |
| 458 | + |
| 459 | + is_complex = False |
| 460 | + max_precision = 0 |
| 461 | + |
| 462 | + for a in arrays: |
| 463 | + t = a.dtype.type |
| 464 | + |
| 465 | + if dpnp.issubdtype(t, dpnp.complexfloating): |
| 466 | + is_complex = True |
| 467 | + if dpnp.issubdtype(t, dpnp.integer): |
| 468 | + t = dpnp.float64 |
| 469 | + |
| 470 | + max_precision = max(max_precision, array_precision.get(t, 0)) |
| 471 | + |
| 472 | + if is_complex: |
| 473 | + return array_type["complex"].get(max_precision, dpnp.complex128) |
| 474 | + return array_type["float"].get(max_precision, dpnp.float64) |
| 475 | + |
| 476 | + |
409 | 477 | def default_float_type(device=None, sycl_queue=None): |
410 | 478 | """ |
411 | 479 | Return a floating type used by default in DPNP depending on device |
|
0 commit comments