|
30 | 30 | _validate_dtype,
|
31 | 31 | )
|
32 | 32 | from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
|
33 |
| -from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype |
| 33 | +from dpctl.tensor._type_utils import _can_cast |
34 | 34 | from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
|
35 | 35 |
|
36 | 36 | from ._type_utils import (
|
37 |
| - WeakComplexType, |
38 |
| - WeakIntegralType, |
39 |
| - _is_weak_dtype, |
40 |
| - _strong_dtype_num_kind, |
41 |
| - _weak_type_num_kind, |
| 37 | + _resolve_one_strong_one_weak_types, |
| 38 | + _resolve_one_strong_two_weak_types, |
42 | 39 | )
|
43 | 40 |
|
44 | 41 |
|
45 |
| -def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev): |
46 |
| - "Resolves weak data types per NEP-0050," |
47 |
| - "where the second and third arguments are" |
48 |
| - "permitted to be weak types" |
49 |
| - if _is_weak_dtype(st_dtype): |
50 |
| - raise ValueError |
51 |
| - if _is_weak_dtype(dtype1): |
52 |
| - if _is_weak_dtype(dtype2): |
53 |
| - kind_num1 = _weak_type_num_kind(dtype1) |
54 |
| - kind_num2 = _weak_type_num_kind(dtype2) |
55 |
| - st_kind_num = _strong_dtype_num_kind(st_dtype) |
56 |
| - |
57 |
| - if kind_num1 > st_kind_num: |
58 |
| - if isinstance(dtype1, WeakIntegralType): |
59 |
| - ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev)) |
60 |
| - elif isinstance(dtype1, WeakComplexType): |
61 |
| - if st_dtype is dpt.float16 or st_dtype is dpt.float32: |
62 |
| - ret_dtype1 = dpt.complex64 |
63 |
| - ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev) |
64 |
| - else: |
65 |
| - ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev) |
66 |
| - else: |
67 |
| - ret_dtype1 = st_dtype |
68 |
| - |
69 |
| - if kind_num2 > st_kind_num: |
70 |
| - if isinstance(dtype2, WeakIntegralType): |
71 |
| - ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev)) |
72 |
| - elif isinstance(dtype2, WeakComplexType): |
73 |
| - if st_dtype is dpt.float16 or st_dtype is dpt.float32: |
74 |
| - ret_dtype2 = dpt.complex64 |
75 |
| - ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev) |
76 |
| - else: |
77 |
| - ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev) |
78 |
| - else: |
79 |
| - ret_dtype2 = st_dtype |
80 |
| - |
81 |
| - return ret_dtype1, ret_dtype2 |
82 |
| - |
83 |
| - max_dt_num_kind, max_dtype = max( |
84 |
| - [ |
85 |
| - (_strong_dtype_num_kind(st_dtype), st_dtype), |
86 |
| - (_strong_dtype_num_kind(dtype2), dtype2), |
87 |
| - ] |
88 |
| - ) |
89 |
| - dt1_kind_num = _weak_type_num_kind(dtype1) |
90 |
| - if dt1_kind_num > max_dt_num_kind: |
91 |
| - if isinstance(dtype1, WeakIntegralType): |
92 |
| - return dpt.dtype(ti.default_device_int_type(dev)), dtype2 |
93 |
| - if isinstance(dtype1, WeakComplexType): |
94 |
| - if max_dtype is dpt.float16 or max_dtype is dpt.float32: |
95 |
| - return dpt.complex64, dtype2 |
96 |
| - return ( |
97 |
| - _to_device_supported_dtype(dpt.complex128, dev), |
98 |
| - dtype2, |
99 |
| - ) |
100 |
| - return _to_device_supported_dtype(dpt.float64, dev), dtype2 |
101 |
| - else: |
102 |
| - return max_dtype, dtype2 |
103 |
| - elif _is_weak_dtype(dtype2): |
104 |
| - max_dt_num_kind, max_dtype = max( |
105 |
| - [ |
106 |
| - (_strong_dtype_num_kind(st_dtype), st_dtype), |
107 |
| - (_strong_dtype_num_kind(dtype1), dtype1), |
108 |
| - ] |
109 |
| - ) |
110 |
| - dt2_kind_num = _weak_type_num_kind(dtype2) |
111 |
| - if dt2_kind_num > max_dt_num_kind: |
112 |
| - if isinstance(dtype2, WeakIntegralType): |
113 |
| - return dtype1, dpt.dtype(ti.default_device_int_type(dev)) |
114 |
| - if isinstance(dtype2, WeakComplexType): |
115 |
| - if max_dtype is dpt.float16 or max_dtype is dpt.float32: |
116 |
| - return dtype1, dpt.complex64 |
117 |
| - return ( |
118 |
| - dtype1, |
119 |
| - _to_device_supported_dtype(dpt.complex128, dev), |
120 |
| - ) |
121 |
| - return dtype1, _to_device_supported_dtype(dpt.float64, dev) |
122 |
| - else: |
123 |
| - return dtype1, max_dtype |
124 |
| - else: |
125 |
| - # both are strong dtypes |
126 |
| - # return unmodified |
127 |
| - return dtype1, dtype2 |
128 |
| - |
129 |
| - |
130 |
| -def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev): |
131 |
| - "Resolves one weak data type with one strong data type per NEP-0050" |
132 |
| - if _is_weak_dtype(st_dtype): |
133 |
| - raise ValueError |
134 |
| - if _is_weak_dtype(dtype): |
135 |
| - st_kind_num = _strong_dtype_num_kind(st_dtype) |
136 |
| - kind_num = _weak_type_num_kind(dtype) |
137 |
| - if kind_num > st_kind_num: |
138 |
| - if isinstance(dtype, WeakIntegralType): |
139 |
| - return dpt.dtype(ti.default_device_int_type(dev)) |
140 |
| - if isinstance(dtype, WeakComplexType): |
141 |
| - if st_dtype is dpt.float16 or st_dtype is dpt.float32: |
142 |
| - return dpt.complex64 |
143 |
| - return _to_device_supported_dtype(dpt.complex128, dev) |
144 |
| - return _to_device_supported_dtype(dpt.float64, dev) |
145 |
| - else: |
146 |
| - return st_dtype |
147 |
| - else: |
148 |
| - return dtype |
149 |
| - |
150 |
| - |
151 | 42 | def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
|
152 | 43 | "Checks if both types `arg1_dtype` and `arg2_dtype` can be"
|
153 | 44 | "cast to `res_dtype` according to the rule `safe`"
|
|
0 commit comments