Skip to content

Commit 78121cf

Browse files
committed
Move _resolve_one_strong_one_weak_types and _resolve_one_strong_two_weak_types to _type_utils
1 parent 6982339 commit 78121cf

File tree

3 files changed

+120
-116
lines changed

3 files changed

+120
-116
lines changed

dpctl/tensor/_clip.py

Lines changed: 3 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -30,124 +30,15 @@
3030
_validate_dtype,
3131
)
3232
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
33-
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
33+
from dpctl.tensor._type_utils import _can_cast
3434
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3535

3636
from ._type_utils import (
37-
WeakComplexType,
38-
WeakIntegralType,
39-
_is_weak_dtype,
40-
_strong_dtype_num_kind,
41-
_weak_type_num_kind,
37+
_resolve_one_strong_one_weak_types,
38+
_resolve_one_strong_two_weak_types,
4239
)
4340

4441

45-
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
46-
"Resolves weak data types per NEP-0050,"
47-
"where the second and third arguments are"
48-
"permitted to be weak types"
49-
if _is_weak_dtype(st_dtype):
50-
raise ValueError
51-
if _is_weak_dtype(dtype1):
52-
if _is_weak_dtype(dtype2):
53-
kind_num1 = _weak_type_num_kind(dtype1)
54-
kind_num2 = _weak_type_num_kind(dtype2)
55-
st_kind_num = _strong_dtype_num_kind(st_dtype)
56-
57-
if kind_num1 > st_kind_num:
58-
if isinstance(dtype1, WeakIntegralType):
59-
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
60-
elif isinstance(dtype1, WeakComplexType):
61-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
62-
ret_dtype1 = dpt.complex64
63-
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
64-
else:
65-
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
66-
else:
67-
ret_dtype1 = st_dtype
68-
69-
if kind_num2 > st_kind_num:
70-
if isinstance(dtype2, WeakIntegralType):
71-
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
72-
elif isinstance(dtype2, WeakComplexType):
73-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
74-
ret_dtype2 = dpt.complex64
75-
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
76-
else:
77-
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
78-
else:
79-
ret_dtype2 = st_dtype
80-
81-
return ret_dtype1, ret_dtype2
82-
83-
max_dt_num_kind, max_dtype = max(
84-
[
85-
(_strong_dtype_num_kind(st_dtype), st_dtype),
86-
(_strong_dtype_num_kind(dtype2), dtype2),
87-
]
88-
)
89-
dt1_kind_num = _weak_type_num_kind(dtype1)
90-
if dt1_kind_num > max_dt_num_kind:
91-
if isinstance(dtype1, WeakIntegralType):
92-
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
93-
if isinstance(dtype1, WeakComplexType):
94-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
95-
return dpt.complex64, dtype2
96-
return (
97-
_to_device_supported_dtype(dpt.complex128, dev),
98-
dtype2,
99-
)
100-
return _to_device_supported_dtype(dpt.float64, dev), dtype2
101-
else:
102-
return max_dtype, dtype2
103-
elif _is_weak_dtype(dtype2):
104-
max_dt_num_kind, max_dtype = max(
105-
[
106-
(_strong_dtype_num_kind(st_dtype), st_dtype),
107-
(_strong_dtype_num_kind(dtype1), dtype1),
108-
]
109-
)
110-
dt2_kind_num = _weak_type_num_kind(dtype2)
111-
if dt2_kind_num > max_dt_num_kind:
112-
if isinstance(dtype2, WeakIntegralType):
113-
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
114-
if isinstance(dtype2, WeakComplexType):
115-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
116-
return dtype1, dpt.complex64
117-
return (
118-
dtype1,
119-
_to_device_supported_dtype(dpt.complex128, dev),
120-
)
121-
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
122-
else:
123-
return dtype1, max_dtype
124-
else:
125-
# both are strong dtypes
126-
# return unmodified
127-
return dtype1, dtype2
128-
129-
130-
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
131-
"Resolves one weak data type with one strong data type per NEP-0050"
132-
if _is_weak_dtype(st_dtype):
133-
raise ValueError
134-
if _is_weak_dtype(dtype):
135-
st_kind_num = _strong_dtype_num_kind(st_dtype)
136-
kind_num = _weak_type_num_kind(dtype)
137-
if kind_num > st_kind_num:
138-
if isinstance(dtype, WeakIntegralType):
139-
return dpt.dtype(ti.default_device_int_type(dev))
140-
if isinstance(dtype, WeakComplexType):
141-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
142-
return dpt.complex64
143-
return _to_device_supported_dtype(dpt.complex128, dev)
144-
return _to_device_supported_dtype(dpt.float64, dev)
145-
else:
146-
return st_dtype
147-
else:
148-
return dtype
149-
150-
15142
def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
15243
"Checks if both types `arg1_dtype` and `arg2_dtype` can be"
15344
"cast to `res_dtype` according to the rule `safe`"

dpctl/tensor/_type_utils.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,112 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
445445
return o1_dtype, o2_dtype
446446

447447

448+
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
449+
"Resolves weak data types per NEP-0050,"
450+
"where the second and third arguments are"
451+
"permitted to be weak types"
452+
if _is_weak_dtype(st_dtype):
453+
raise ValueError
454+
if _is_weak_dtype(dtype1):
455+
if _is_weak_dtype(dtype2):
456+
kind_num1 = _weak_type_num_kind(dtype1)
457+
kind_num2 = _weak_type_num_kind(dtype2)
458+
st_kind_num = _strong_dtype_num_kind(st_dtype)
459+
460+
if kind_num1 > st_kind_num:
461+
if isinstance(dtype1, WeakIntegralType):
462+
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
463+
elif isinstance(dtype1, WeakComplexType):
464+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
465+
ret_dtype1 = dpt.complex64
466+
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
467+
else:
468+
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
469+
else:
470+
ret_dtype1 = st_dtype
471+
472+
if kind_num2 > st_kind_num:
473+
if isinstance(dtype2, WeakIntegralType):
474+
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
475+
elif isinstance(dtype2, WeakComplexType):
476+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
477+
ret_dtype2 = dpt.complex64
478+
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
479+
else:
480+
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
481+
else:
482+
ret_dtype2 = st_dtype
483+
484+
return ret_dtype1, ret_dtype2
485+
486+
max_dt_num_kind, max_dtype = max(
487+
[
488+
(_strong_dtype_num_kind(st_dtype), st_dtype),
489+
(_strong_dtype_num_kind(dtype2), dtype2),
490+
]
491+
)
492+
dt1_kind_num = _weak_type_num_kind(dtype1)
493+
if dt1_kind_num > max_dt_num_kind:
494+
if isinstance(dtype1, WeakIntegralType):
495+
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
496+
if isinstance(dtype1, WeakComplexType):
497+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
498+
return dpt.complex64, dtype2
499+
return (
500+
_to_device_supported_dtype(dpt.complex128, dev),
501+
dtype2,
502+
)
503+
return _to_device_supported_dtype(dpt.float64, dev), dtype2
504+
else:
505+
return max_dtype, dtype2
506+
elif _is_weak_dtype(dtype2):
507+
max_dt_num_kind, max_dtype = max(
508+
[
509+
(_strong_dtype_num_kind(st_dtype), st_dtype),
510+
(_strong_dtype_num_kind(dtype1), dtype1),
511+
]
512+
)
513+
dt2_kind_num = _weak_type_num_kind(dtype2)
514+
if dt2_kind_num > max_dt_num_kind:
515+
if isinstance(dtype2, WeakIntegralType):
516+
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
517+
if isinstance(dtype2, WeakComplexType):
518+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
519+
return dtype1, dpt.complex64
520+
return (
521+
dtype1,
522+
_to_device_supported_dtype(dpt.complex128, dev),
523+
)
524+
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
525+
else:
526+
return dtype1, max_dtype
527+
else:
528+
# both are strong dtypes
529+
# return unmodified
530+
return dtype1, dtype2
531+
532+
533+
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
534+
"Resolves one weak data type with one strong data type per NEP-0050"
535+
if _is_weak_dtype(st_dtype):
536+
raise ValueError
537+
if _is_weak_dtype(dtype):
538+
st_kind_num = _strong_dtype_num_kind(st_dtype)
539+
kind_num = _weak_type_num_kind(dtype)
540+
if kind_num > st_kind_num:
541+
if isinstance(dtype, WeakIntegralType):
542+
return dpt.dtype(ti.default_device_int_type(dev))
543+
if isinstance(dtype, WeakComplexType):
544+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
545+
return dpt.complex64
546+
return _to_device_supported_dtype(dpt.complex128, dev)
547+
return _to_device_supported_dtype(dpt.float64, dev)
548+
else:
549+
return st_dtype
550+
else:
551+
return dtype
552+
553+
448554
class finfo_object:
449555
"""
450556
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -833,6 +939,8 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
833939
"_acceptance_fn_divide",
834940
"_acceptance_fn_negative",
835941
"_acceptance_fn_subtract",
942+
"_resolve_one_strong_one_weak_types",
943+
"_resolve_one_strong_two_weak_types",
836944
"_resolve_weak_types",
837945
"_resolve_weak_types_comparisons",
838946
"_weak_type_num_kind",

dpctl/tensor/_utility_functions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@
2020
import dpctl.tensor._tensor_impl as ti
2121
import dpctl.tensor._tensor_reductions_impl as tri
2222
import dpctl.utils as du
23-
from dpctl.tensor._clip import (
24-
_resolve_one_strong_one_weak_types,
25-
_resolve_one_strong_two_weak_types,
26-
)
2723
from dpctl.tensor._elementwise_common import (
2824
_get_dtype,
2925
_get_queue_usm_type,
@@ -32,6 +28,10 @@
3228
)
3329

3430
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
31+
from ._type_utils import (
32+
_resolve_one_strong_one_weak_types,
33+
_resolve_one_strong_two_weak_types,
34+
)
3535

3636

3737
def _boolean_reduction(x, axis, keepdims, func):
@@ -159,6 +159,8 @@ def any(x, /, *, axis=None, keepdims=False):
159159

160160

161161
def _validate_diff_shape(sh1, sh2, axis):
162+
"""Utility for validating that two shapes `sh1` and `sh2`
163+
are possible to concatenate along `axis`."""
162164
if not sh2:
163165
# scalars will always be accepted
164166
return True
@@ -173,6 +175,9 @@ def _validate_diff_shape(sh1, sh2, axis):
173175

174176

175177
def _concat_diff_input(arr, axis, prepend, append):
178+
"""Concatenates `arr`, `prepend` and, `append` along `axis`,
179+
where `arr` is an array and `prepend` and `append` are
180+
any mixture of arrays and scalars."""
176181
if prepend is not None and append is not None:
177182
q1, x_usm_type = arr.sycl_queue, arr.usm_type
178183
q2, prepend_usm_type = _get_queue_usm_type(prepend)

0 commit comments

Comments
 (0)