Skip to content

Commit 01dcdd0

Browse files
Move common_type() to dpnp_iface_manipulation.py
1 parent cfb2e99 commit 01dcdd0

File tree

2 files changed

+66
-59
lines changed

2 files changed

+66
-59
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
"""
3939

4040

41-
import functools
4241
import math
4342
import operator
4443
import warnings
@@ -102,7 +101,6 @@ class UniqueInverseResult(NamedTuple):
102101
"broadcast_shapes",
103102
"broadcast_to",
104103
"can_cast",
105-
"common_type",
106104
"column_stack",
107105
"concat",
108106
"concatenate",
@@ -1312,63 +1310,6 @@ def can_cast(from_, to, casting="safe"):
13121310
return dpt.can_cast(dtype_from, to, casting=casting)
13131311

13141312

1315-
def common_type(*arrays):
1316-
"""
1317-
Return a scalar type which is common to the input arrays.
1318-
1319-
The return type will always be an inexact (i.e. floating point or complex)
1320-
scalar type, even if all the arrays are integer arrays.
1321-
If one of the inputs is an integer array, the minimum precision type
1322-
that is returned is the default floating point data type for the device
1323-
where the input arrays are allocated.
1324-
1325-
For full documentation refer to :obj:`numpy.common_type`.
1326-
1327-
Parameters
1328-
----------
1329-
arrays: {dpnp.ndarray, usm_ndarray}
1330-
Input arrays.
1331-
1332-
Returns
1333-
-------
1334-
out: data type
1335-
Data type object.
1336-
1337-
Examples
1338-
--------
1339-
>>> import dpnp as np
1340-
>>> np.common_type(np.arange(2, dtype=np.float32))
1341-
numpy.float32
1342-
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
1343-
numpy.float64 # may vary
1344-
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
1345-
numpy.complex128 # may vary
1346-
1347-
"""
1348-
1349-
if len(arrays) == 0:
1350-
return (
1351-
dpnp.float16
1352-
if dpctl.select_default_device().has_aspect_fp16
1353-
else dpnp.float32
1354-
)
1355-
1356-
dpnp.check_supported_arrays_type(*arrays)
1357-
1358-
_, exec_q = get_usm_allocations(arrays)
1359-
default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q)
1360-
dtypes = []
1361-
for a in arrays:
1362-
if not dpnp.issubdtype(a.dtype, dpnp.number):
1363-
raise TypeError("can't get common type for non-numeric array")
1364-
if dpnp.issubdtype(a.dtype, dpnp.integer):
1365-
dtypes.append(default_float_dtype)
1366-
else:
1367-
dtypes.append(a.dtype)
1368-
1369-
return functools.reduce(numpy.promote_types, dtypes).type
1370-
1371-
13721313
def column_stack(tup):
13731314
"""
13741315
Stacks 1-D and 2-D arrays as columns into a 2-D array.

dpnp/dpnp_iface_types.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,25 @@
3232
This module provides public type interface file for the library
3333
"""
3434

35+
import functools
36+
37+
import dpctl
3538
import dpctl.tensor as dpt
3639
import numpy
3740

41+
import dpnp
42+
3843
from .dpnp_array import dpnp_array
3944

45+
# pylint: disable=no-name-in-module
46+
from .dpnp_utils import get_usm_allocations
47+
4048
__all__ = [
4149
"bool",
4250
"bool_",
4351
"byte",
4452
"cdouble",
53+
"common_type",
4554
"complex128",
4655
"complex64",
4756
"complexfloating",
@@ -145,6 +154,63 @@
145154
pi = numpy.pi
146155

147156

157+
def common_type(*arrays):
158+
"""
159+
Return a scalar type which is common to the input arrays.
160+
161+
The return type will always be an inexact (i.e. floating point or complex)
162+
scalar type, even if all the arrays are integer arrays.
163+
If one of the inputs is an integer array, the minimum precision type
164+
that is returned is the default floating point data type for the device
165+
where the input arrays are allocated.
166+
167+
For full documentation refer to :obj:`numpy.common_type`.
168+
169+
Parameters
170+
----------
171+
arrays: {dpnp.ndarray, usm_ndarray}
172+
Input arrays.
173+
174+
Returns
175+
-------
176+
out: data type
177+
Data type object.
178+
179+
Examples
180+
--------
181+
>>> import dpnp as np
182+
>>> np.common_type(np.arange(2, dtype=np.float32))
183+
numpy.float32
184+
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
185+
numpy.float64 # may vary
186+
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
187+
numpy.complex128 # may vary
188+
189+
"""
190+
191+
if len(arrays) == 0:
192+
return (
193+
dpnp.float16
194+
if dpctl.select_default_device().has_aspect_fp16
195+
else dpnp.float32
196+
)
197+
198+
dpnp.check_supported_arrays_type(*arrays)
199+
200+
_, exec_q = get_usm_allocations(arrays)
201+
default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q)
202+
dtypes = []
203+
for a in arrays:
204+
if not dpnp.issubdtype(a.dtype, dpnp.number):
205+
raise TypeError("can't get common type for non-numeric array")
206+
if dpnp.issubdtype(a.dtype, dpnp.integer):
207+
dtypes.append(default_float_dtype)
208+
else:
209+
dtypes.append(a.dtype)
210+
211+
return functools.reduce(numpy.promote_types, dtypes).type
212+
213+
148214
# pylint: disable=redefined-outer-name
149215
def finfo(dtype):
150216
"""

0 commit comments

Comments
 (0)