Skip to content

Commit 442ba6f

Browse files
committed
Move _resolve_one_strong_one_weak_types and _resolve_one_strong_two_weak_types to _type_utils
1 parent 9ddecb0 commit 442ba6f

File tree

3 files changed

+120
-116
lines changed

3 files changed

+120
-116
lines changed

dpctl/tensor/_clip.py

Lines changed: 3 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -30,124 +30,15 @@
3030
_validate_dtype,
3131
)
3232
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
33-
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
33+
from dpctl.tensor._type_utils import _can_cast
3434
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3535

3636
from ._type_utils import (
37-
WeakComplexType,
38-
WeakIntegralType,
39-
_is_weak_dtype,
40-
_strong_dtype_num_kind,
41-
_weak_type_num_kind,
37+
_resolve_one_strong_one_weak_types,
38+
_resolve_one_strong_two_weak_types,
4239
)
4340

4441

45-
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
46-
"Resolves weak data types per NEP-0050,"
47-
"where the second and third arguments are"
48-
"permitted to be weak types"
49-
if _is_weak_dtype(st_dtype):
50-
raise ValueError
51-
if _is_weak_dtype(dtype1):
52-
if _is_weak_dtype(dtype2):
53-
kind_num1 = _weak_type_num_kind(dtype1)
54-
kind_num2 = _weak_type_num_kind(dtype2)
55-
st_kind_num = _strong_dtype_num_kind(st_dtype)
56-
57-
if kind_num1 > st_kind_num:
58-
if isinstance(dtype1, WeakIntegralType):
59-
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
60-
elif isinstance(dtype1, WeakComplexType):
61-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
62-
ret_dtype1 = dpt.complex64
63-
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
64-
else:
65-
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
66-
else:
67-
ret_dtype1 = st_dtype
68-
69-
if kind_num2 > st_kind_num:
70-
if isinstance(dtype2, WeakIntegralType):
71-
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
72-
elif isinstance(dtype2, WeakComplexType):
73-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
74-
ret_dtype2 = dpt.complex64
75-
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
76-
else:
77-
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
78-
else:
79-
ret_dtype2 = st_dtype
80-
81-
return ret_dtype1, ret_dtype2
82-
83-
max_dt_num_kind, max_dtype = max(
84-
[
85-
(_strong_dtype_num_kind(st_dtype), st_dtype),
86-
(_strong_dtype_num_kind(dtype2), dtype2),
87-
]
88-
)
89-
dt1_kind_num = _weak_type_num_kind(dtype1)
90-
if dt1_kind_num > max_dt_num_kind:
91-
if isinstance(dtype1, WeakIntegralType):
92-
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
93-
if isinstance(dtype1, WeakComplexType):
94-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
95-
return dpt.complex64, dtype2
96-
return (
97-
_to_device_supported_dtype(dpt.complex128, dev),
98-
dtype2,
99-
)
100-
return _to_device_supported_dtype(dpt.float64, dev), dtype2
101-
else:
102-
return max_dtype, dtype2
103-
elif _is_weak_dtype(dtype2):
104-
max_dt_num_kind, max_dtype = max(
105-
[
106-
(_strong_dtype_num_kind(st_dtype), st_dtype),
107-
(_strong_dtype_num_kind(dtype1), dtype1),
108-
]
109-
)
110-
dt2_kind_num = _weak_type_num_kind(dtype2)
111-
if dt2_kind_num > max_dt_num_kind:
112-
if isinstance(dtype2, WeakIntegralType):
113-
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
114-
if isinstance(dtype2, WeakComplexType):
115-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
116-
return dtype1, dpt.complex64
117-
return (
118-
dtype1,
119-
_to_device_supported_dtype(dpt.complex128, dev),
120-
)
121-
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
122-
else:
123-
return dtype1, max_dtype
124-
else:
125-
# both are strong dtypes
126-
# return unmodified
127-
return dtype1, dtype2
128-
129-
130-
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
131-
"Resolves one weak data type with one strong data type per NEP-0050"
132-
if _is_weak_dtype(st_dtype):
133-
raise ValueError
134-
if _is_weak_dtype(dtype):
135-
st_kind_num = _strong_dtype_num_kind(st_dtype)
136-
kind_num = _weak_type_num_kind(dtype)
137-
if kind_num > st_kind_num:
138-
if isinstance(dtype, WeakIntegralType):
139-
return dpt.dtype(ti.default_device_int_type(dev))
140-
if isinstance(dtype, WeakComplexType):
141-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
142-
return dpt.complex64
143-
return _to_device_supported_dtype(dpt.complex128, dev)
144-
return _to_device_supported_dtype(dpt.float64, dev)
145-
else:
146-
return st_dtype
147-
else:
148-
return dtype
149-
150-
15142
def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
15243
"Checks if both types `arg1_dtype` and `arg2_dtype` can be"
15344
"cast to `res_dtype` according to the rule `safe`"

dpctl/tensor/_type_utils.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,112 @@ def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
450450
return o1_dtype, o2_dtype
451451

452452

453+
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
454+
"Resolves weak data types per NEP-0050,"
455+
"where the second and third arguments are"
456+
"permitted to be weak types"
457+
if _is_weak_dtype(st_dtype):
458+
raise ValueError
459+
if _is_weak_dtype(dtype1):
460+
if _is_weak_dtype(dtype2):
461+
kind_num1 = _weak_type_num_kind(dtype1)
462+
kind_num2 = _weak_type_num_kind(dtype2)
463+
st_kind_num = _strong_dtype_num_kind(st_dtype)
464+
465+
if kind_num1 > st_kind_num:
466+
if isinstance(dtype1, WeakIntegralType):
467+
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
468+
elif isinstance(dtype1, WeakComplexType):
469+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
470+
ret_dtype1 = dpt.complex64
471+
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
472+
else:
473+
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
474+
else:
475+
ret_dtype1 = st_dtype
476+
477+
if kind_num2 > st_kind_num:
478+
if isinstance(dtype2, WeakIntegralType):
479+
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
480+
elif isinstance(dtype2, WeakComplexType):
481+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
482+
ret_dtype2 = dpt.complex64
483+
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
484+
else:
485+
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
486+
else:
487+
ret_dtype2 = st_dtype
488+
489+
return ret_dtype1, ret_dtype2
490+
491+
max_dt_num_kind, max_dtype = max(
492+
[
493+
(_strong_dtype_num_kind(st_dtype), st_dtype),
494+
(_strong_dtype_num_kind(dtype2), dtype2),
495+
]
496+
)
497+
dt1_kind_num = _weak_type_num_kind(dtype1)
498+
if dt1_kind_num > max_dt_num_kind:
499+
if isinstance(dtype1, WeakIntegralType):
500+
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
501+
if isinstance(dtype1, WeakComplexType):
502+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
503+
return dpt.complex64, dtype2
504+
return (
505+
_to_device_supported_dtype(dpt.complex128, dev),
506+
dtype2,
507+
)
508+
return _to_device_supported_dtype(dpt.float64, dev), dtype2
509+
else:
510+
return max_dtype, dtype2
511+
elif _is_weak_dtype(dtype2):
512+
max_dt_num_kind, max_dtype = max(
513+
[
514+
(_strong_dtype_num_kind(st_dtype), st_dtype),
515+
(_strong_dtype_num_kind(dtype1), dtype1),
516+
]
517+
)
518+
dt2_kind_num = _weak_type_num_kind(dtype2)
519+
if dt2_kind_num > max_dt_num_kind:
520+
if isinstance(dtype2, WeakIntegralType):
521+
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
522+
if isinstance(dtype2, WeakComplexType):
523+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
524+
return dtype1, dpt.complex64
525+
return (
526+
dtype1,
527+
_to_device_supported_dtype(dpt.complex128, dev),
528+
)
529+
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
530+
else:
531+
return dtype1, max_dtype
532+
else:
533+
# both are strong dtypes
534+
# return unmodified
535+
return dtype1, dtype2
536+
537+
538+
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
539+
"Resolves one weak data type with one strong data type per NEP-0050"
540+
if _is_weak_dtype(st_dtype):
541+
raise ValueError
542+
if _is_weak_dtype(dtype):
543+
st_kind_num = _strong_dtype_num_kind(st_dtype)
544+
kind_num = _weak_type_num_kind(dtype)
545+
if kind_num > st_kind_num:
546+
if isinstance(dtype, WeakIntegralType):
547+
return dpt.dtype(ti.default_device_int_type(dev))
548+
if isinstance(dtype, WeakComplexType):
549+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
550+
return dpt.complex64
551+
return _to_device_supported_dtype(dpt.complex128, dev)
552+
return _to_device_supported_dtype(dpt.float64, dev)
553+
else:
554+
return st_dtype
555+
else:
556+
return dtype
557+
558+
453559
class finfo_object:
454560
"""
455561
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -838,6 +944,8 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
838944
"_acceptance_fn_divide",
839945
"_acceptance_fn_negative",
840946
"_acceptance_fn_subtract",
947+
"_resolve_one_strong_one_weak_types",
948+
"_resolve_one_strong_two_weak_types",
841949
"_resolve_weak_types",
842950
"_resolve_weak_types_all_py_ints",
843951
"_weak_type_num_kind",

dpctl/tensor/_utility_functions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@
2020
import dpctl.tensor._tensor_impl as ti
2121
import dpctl.tensor._tensor_reductions_impl as tri
2222
import dpctl.utils as du
23-
from dpctl.tensor._clip import (
24-
_resolve_one_strong_one_weak_types,
25-
_resolve_one_strong_two_weak_types,
26-
)
2723
from dpctl.tensor._elementwise_common import (
2824
_get_dtype,
2925
_get_queue_usm_type,
@@ -32,6 +28,10 @@
3228
)
3329

3430
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
31+
from ._type_utils import (
32+
_resolve_one_strong_one_weak_types,
33+
_resolve_one_strong_two_weak_types,
34+
)
3535

3636

3737
def _boolean_reduction(x, axis, keepdims, func):
@@ -159,6 +159,8 @@ def any(x, /, *, axis=None, keepdims=False):
159159

160160

161161
def _validate_diff_shape(sh1, sh2, axis):
162+
"""Utility for validating that two shapes `sh1` and `sh2`
163+
are possible to concatenate along `axis`."""
162164
if not sh2:
163165
# scalars will always be accepted
164166
return True
@@ -173,6 +175,9 @@ def _validate_diff_shape(sh1, sh2, axis):
173175

174176

175177
def _concat_diff_input(arr, axis, prepend, append):
178+
"""Concatenates `arr`, `prepend` and, `append` along `axis`,
179+
where `arr` is an array and `prepend` and `append` are
180+
any mixture of arrays and scalars."""
176181
if prepend is not None and append is not None:
177182
q1, x_usm_type = arr.sycl_queue, arr.usm_type
178183
q2, prepend_usm_type = _get_queue_usm_type(prepend)

0 commit comments

Comments
 (0)