Skip to content

Commit f975dbb

Browse files
committed
Move _resolve_one_strong_one_weak_types and _resolve_one_strong_two_weak_types to _type_utils
1 parent b42fb07 commit f975dbb

File tree

3 files changed

+121
-116
lines changed

3 files changed

+121
-116
lines changed

dpctl/tensor/_clip.py

Lines changed: 3 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -30,124 +30,15 @@
3030
_validate_dtype,
3131
)
3232
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
33-
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
33+
from dpctl.tensor._type_utils import _can_cast
3434
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3535

3636
from ._type_utils import (
37-
WeakComplexType,
38-
WeakIntegralType,
39-
_is_weak_dtype,
40-
_strong_dtype_num_kind,
41-
_weak_type_num_kind,
37+
_resolve_one_strong_one_weak_types,
38+
_resolve_one_strong_two_weak_types,
4239
)
4340

4441

45-
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
46-
"Resolves weak data types per NEP-0050,"
47-
"where the second and third arguments are"
48-
"permitted to be weak types"
49-
if _is_weak_dtype(st_dtype):
50-
raise ValueError
51-
if _is_weak_dtype(dtype1):
52-
if _is_weak_dtype(dtype2):
53-
kind_num1 = _weak_type_num_kind(dtype1)
54-
kind_num2 = _weak_type_num_kind(dtype2)
55-
st_kind_num = _strong_dtype_num_kind(st_dtype)
56-
57-
if kind_num1 > st_kind_num:
58-
if isinstance(dtype1, WeakIntegralType):
59-
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
60-
elif isinstance(dtype1, WeakComplexType):
61-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
62-
ret_dtype1 = dpt.complex64
63-
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
64-
else:
65-
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
66-
else:
67-
ret_dtype1 = st_dtype
68-
69-
if kind_num2 > st_kind_num:
70-
if isinstance(dtype2, WeakIntegralType):
71-
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
72-
elif isinstance(dtype2, WeakComplexType):
73-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
74-
ret_dtype2 = dpt.complex64
75-
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
76-
else:
77-
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
78-
else:
79-
ret_dtype2 = st_dtype
80-
81-
return ret_dtype1, ret_dtype2
82-
83-
max_dt_num_kind, max_dtype = max(
84-
[
85-
(_strong_dtype_num_kind(st_dtype), st_dtype),
86-
(_strong_dtype_num_kind(dtype2), dtype2),
87-
]
88-
)
89-
dt1_kind_num = _weak_type_num_kind(dtype1)
90-
if dt1_kind_num > max_dt_num_kind:
91-
if isinstance(dtype1, WeakIntegralType):
92-
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
93-
if isinstance(dtype1, WeakComplexType):
94-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
95-
return dpt.complex64, dtype2
96-
return (
97-
_to_device_supported_dtype(dpt.complex128, dev),
98-
dtype2,
99-
)
100-
return _to_device_supported_dtype(dpt.float64, dev), dtype2
101-
else:
102-
return max_dtype, dtype2
103-
elif _is_weak_dtype(dtype2):
104-
max_dt_num_kind, max_dtype = max(
105-
[
106-
(_strong_dtype_num_kind(st_dtype), st_dtype),
107-
(_strong_dtype_num_kind(dtype1), dtype1),
108-
]
109-
)
110-
dt2_kind_num = _weak_type_num_kind(dtype2)
111-
if dt2_kind_num > max_dt_num_kind:
112-
if isinstance(dtype2, WeakIntegralType):
113-
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
114-
if isinstance(dtype2, WeakComplexType):
115-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
116-
return dtype1, dpt.complex64
117-
return (
118-
dtype1,
119-
_to_device_supported_dtype(dpt.complex128, dev),
120-
)
121-
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
122-
else:
123-
return dtype1, max_dtype
124-
else:
125-
# both are strong dtypes
126-
# return unmodified
127-
return dtype1, dtype2
128-
129-
130-
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
131-
"Resolves one weak data type with one strong data type per NEP-0050"
132-
if _is_weak_dtype(st_dtype):
133-
raise ValueError
134-
if _is_weak_dtype(dtype):
135-
st_kind_num = _strong_dtype_num_kind(st_dtype)
136-
kind_num = _weak_type_num_kind(dtype)
137-
if kind_num > st_kind_num:
138-
if isinstance(dtype, WeakIntegralType):
139-
return dpt.dtype(ti.default_device_int_type(dev))
140-
if isinstance(dtype, WeakComplexType):
141-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
142-
return dpt.complex64
143-
return _to_device_supported_dtype(dpt.complex128, dev)
144-
return _to_device_supported_dtype(dpt.float64, dev)
145-
else:
146-
return st_dtype
147-
else:
148-
return dtype
149-
150-
15142
def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
15243
"Checks if both types `arg1_dtype` and `arg2_dtype` can be"
15344
"cast to `res_dtype` according to the rule `safe`"

dpctl/tensor/_type_utils.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,112 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
445445
return o1_dtype, o2_dtype
446446

447447

448+
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
449+
"Resolves weak data types per NEP-0050,"
450+
"where the second and third arguments are"
451+
"permitted to be weak types"
452+
if _is_weak_dtype(st_dtype):
453+
raise ValueError
454+
if _is_weak_dtype(dtype1):
455+
if _is_weak_dtype(dtype2):
456+
kind_num1 = _weak_type_num_kind(dtype1)
457+
kind_num2 = _weak_type_num_kind(dtype2)
458+
st_kind_num = _strong_dtype_num_kind(st_dtype)
459+
460+
if kind_num1 > st_kind_num:
461+
if isinstance(dtype1, WeakIntegralType):
462+
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
463+
elif isinstance(dtype1, WeakComplexType):
464+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
465+
ret_dtype1 = dpt.complex64
466+
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
467+
else:
468+
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
469+
else:
470+
ret_dtype1 = st_dtype
471+
472+
if kind_num2 > st_kind_num:
473+
if isinstance(dtype2, WeakIntegralType):
474+
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
475+
elif isinstance(dtype2, WeakComplexType):
476+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
477+
ret_dtype2 = dpt.complex64
478+
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
479+
else:
480+
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
481+
else:
482+
ret_dtype2 = st_dtype
483+
484+
return ret_dtype1, ret_dtype2
485+
486+
max_dt_num_kind, max_dtype = max(
487+
[
488+
(_strong_dtype_num_kind(st_dtype), st_dtype),
489+
(_strong_dtype_num_kind(dtype2), dtype2),
490+
]
491+
)
492+
dt1_kind_num = _weak_type_num_kind(dtype1)
493+
if dt1_kind_num > max_dt_num_kind:
494+
if isinstance(dtype1, WeakIntegralType):
495+
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
496+
if isinstance(dtype1, WeakComplexType):
497+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
498+
return dpt.complex64, dtype2
499+
return (
500+
_to_device_supported_dtype(dpt.complex128, dev),
501+
dtype2,
502+
)
503+
return _to_device_supported_dtype(dpt.float64, dev), dtype2
504+
else:
505+
return max_dtype, dtype2
506+
elif _is_weak_dtype(dtype2):
507+
max_dt_num_kind, max_dtype = max(
508+
[
509+
(_strong_dtype_num_kind(st_dtype), st_dtype),
510+
(_strong_dtype_num_kind(dtype1), dtype1),
511+
]
512+
)
513+
dt2_kind_num = _weak_type_num_kind(dtype2)
514+
if dt2_kind_num > max_dt_num_kind:
515+
if isinstance(dtype2, WeakIntegralType):
516+
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
517+
if isinstance(dtype2, WeakComplexType):
518+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
519+
return dtype1, dpt.complex64
520+
return (
521+
dtype1,
522+
_to_device_supported_dtype(dpt.complex128, dev),
523+
)
524+
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
525+
else:
526+
return dtype1, max_dtype
527+
else:
528+
# both are strong dtypes
529+
# return unmodified
530+
return dtype1, dtype2
531+
532+
533+
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
534+
"Resolves one weak data type with one strong data type per NEP-0050"
535+
if _is_weak_dtype(st_dtype):
536+
raise ValueError
537+
if _is_weak_dtype(dtype):
538+
st_kind_num = _strong_dtype_num_kind(st_dtype)
539+
kind_num = _weak_type_num_kind(dtype)
540+
if kind_num > st_kind_num:
541+
if isinstance(dtype, WeakIntegralType):
542+
return dpt.dtype(ti.default_device_int_type(dev))
543+
if isinstance(dtype, WeakComplexType):
544+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
545+
return dpt.complex64
546+
return _to_device_supported_dtype(dpt.complex128, dev)
547+
return _to_device_supported_dtype(dpt.float64, dev)
548+
else:
549+
return st_dtype
550+
else:
551+
return dtype
552+
553+
448554
class finfo_object:
449555
"""
450556
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -833,6 +939,8 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
833939
"_acceptance_fn_divide",
834940
"_acceptance_fn_negative",
835941
"_acceptance_fn_subtract",
942+
"_resolve_one_strong_one_weak_types",
943+
"_resolve_one_strong_two_weak_types",
836944
"_resolve_weak_types",
837945
"_resolve_weak_types_comparisons",
838946
"_weak_type_num_kind",

dpctl/tensor/_utility_functions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,18 @@
2323
import dpctl.tensor._tensor_impl as ti
2424
import dpctl.tensor._tensor_reductions_impl as tri
2525
import dpctl.utils as du
26-
from dpctl.tensor._clip import (
27-
_resolve_one_strong_one_weak_types,
28-
_resolve_one_strong_two_weak_types,
29-
)
3026
from dpctl.tensor._elementwise_common import (
3127
_get_dtype,
3228
_get_queue_usm_type,
3329
_get_shape,
3430
_validate_dtype,
3531
)
3632

33+
from ._type_utils import (
34+
_resolve_one_strong_one_weak_types,
35+
_resolve_one_strong_two_weak_types,
36+
)
37+
3738

3839
def _boolean_reduction(x, axis, keepdims, func):
3940
if not isinstance(x, dpt.usm_ndarray):
@@ -160,6 +161,8 @@ def any(x, /, *, axis=None, keepdims=False):
160161

161162

162163
def _validate_diff_shape(sh1, sh2, axis):
164+
"""Utility for validating that two shapes `sh1` and `sh2`
165+
are possible to concatenate along `axis`."""
163166
if not sh2:
164167
# scalars will always be accepted
165168
return True
@@ -174,6 +177,9 @@ def _validate_diff_shape(sh1, sh2, axis):
174177

175178

176179
def _concat_diff_input(arr, axis, prepend, append):
180+
"""Concatenates `arr`, `prepend` and, `append` along `axis`,
181+
where `arr` is an array and `prepend` and `append` are
182+
any mixture of arrays and scalars."""
177183
if prepend is not None and append is not None:
178184
q1, x_usm_type = arr.sycl_queue, arr.usm_type
179185
q2, prepend_usm_type = _get_queue_usm_type(prepend)

0 commit comments

Comments
 (0)