|
32 | 32 | This module provides public type interface file for the library |
33 | 33 | """ |
34 | 34 |
|
| 35 | +import functools |
| 36 | + |
| 37 | +import dpctl |
35 | 38 | import dpctl.tensor as dpt |
36 | 39 | import numpy |
37 | 40 |
|
| 41 | +import dpnp |
| 42 | + |
38 | 43 | from .dpnp_array import dpnp_array |
39 | 44 |
|
| 45 | +# pylint: disable=no-name-in-module |
| 46 | +from .dpnp_utils import get_usm_allocations |
| 47 | + |
40 | 48 | __all__ = [ |
41 | 49 | "bool", |
42 | 50 | "bool_", |
43 | 51 | "byte", |
44 | 52 | "cdouble", |
| 53 | + "common_type", |
45 | 54 | "complex128", |
46 | 55 | "complex64", |
47 | 56 | "complexfloating", |
|
145 | 154 | pi = numpy.pi |
146 | 155 |
|
147 | 156 |
|
| 157 | +def common_type(*arrays): |
| 158 | + """ |
| 159 | + Return a scalar type which is common to the input arrays. |
| 160 | +
|
| 161 | + The return type will always be an inexact (i.e. floating point or complex) |
| 162 | + scalar type, even if all the arrays are integer arrays. |
| 163 | + If one of the inputs is an integer array, the minimum precision type |
| 164 | + that is returned is the default floating point data type for the device |
| 165 | + where the input arrays are allocated. |
| 166 | +
|
| 167 | + For full documentation refer to :obj:`numpy.common_type`. |
| 168 | +
|
| 169 | + Parameters |
| 170 | + ---------- |
| 171 | + arrays: {dpnp.ndarray, usm_ndarray} |
| 172 | + Input arrays. |
| 173 | +
|
| 174 | + Returns |
| 175 | + ------- |
| 176 | + out: data type |
| 177 | + Data type object. |
| 178 | +
|
| 179 | + See Also |
| 180 | + -------- |
| 181 | + :obj:`dpnp.dtype` : Create a data type object. |
| 182 | +
|
| 183 | + Examples |
| 184 | + -------- |
| 185 | + >>> import dpnp as np |
| 186 | + >>> np.common_type(np.arange(2, dtype=np.float32)) |
| 187 | + numpy.float32 |
| 188 | + >>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2)) |
| 189 | + numpy.float64 # may vary |
| 190 | + >>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0])) |
| 191 | + numpy.complex128 # may vary |
| 192 | +
|
| 193 | + """ |
| 194 | + |
| 195 | + if len(arrays) == 0: |
| 196 | + return ( |
| 197 | + dpnp.float16 |
| 198 | + if dpctl.select_default_device().has_aspect_fp16 |
| 199 | + else dpnp.float32 |
| 200 | + ) |
| 201 | + |
| 202 | + dpnp.check_supported_arrays_type(*arrays) |
| 203 | + |
| 204 | + _, exec_q = get_usm_allocations(arrays) |
| 205 | + default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q) |
| 206 | + dtypes = [] |
| 207 | + for a in arrays: |
| 208 | + if not dpnp.issubdtype(a.dtype, dpnp.number): |
| 209 | + raise TypeError("can't get common type for non-numeric array") |
| 210 | + if dpnp.issubdtype(a.dtype, dpnp.integer): |
| 211 | + dtypes.append(default_float_dtype) |
| 212 | + else: |
| 213 | + dtypes.append(a.dtype) |
| 214 | + |
| 215 | + return functools.reduce(numpy.promote_types, dtypes).type |
| 216 | + |
| 217 | + |
148 | 218 | # pylint: disable=redefined-outer-name |
149 | 219 | def finfo(dtype): |
150 | 220 | """ |
|
0 commit comments