1717import dpctl
1818import dpctl .tensor as dpt
1919import dpctl .tensor ._tensor_impl as ti
20- from dpctl .tensor ._manipulation_functions import _broadcast_shapes
20+ from dpctl .tensor ._elementwise_common import (
21+ _get_dtype ,
22+ _get_queue_usm_type ,
23+ _get_shape ,
24+ _validate_dtype ,
25+ )
26+ from dpctl .tensor ._manipulation_functions import _broadcast_shape_impl
2127from dpctl .utils import ExecutionPlacementError , SequentialOrderManager
2228
2329from ._copy_utils import _empty_like_orderK , _empty_like_triple_orderK
24- from ._type_utils import _all_data_types , _can_cast
30+ from ._type_utils import (
31+ WeakBooleanType ,
32+ WeakComplexType ,
33+ WeakFloatingType ,
34+ WeakIntegralType ,
35+ _all_data_types ,
36+ _can_cast ,
37+ _is_weak_dtype ,
38+ _strong_dtype_num_kind ,
39+ _to_device_supported_dtype ,
40+ _weak_type_num_kind ,
41+ )
42+
43+
def _default_dtype_from_weak_type(dt, dev):
    """Map a weak type marker to the default concrete dtype for a device.

    Args:
        dt: Object tested against ``WeakBooleanType``, ``WeakIntegralType``,
            ``WeakFloatingType`` and ``WeakComplexType``, in that order.
        dev: Device forwarded to the ``ti.default_device_*_type`` queries.

    Returns:
        ``dpt.bool`` for a weak boolean; otherwise the ``dpt.dtype`` built
        from the device's default integral, floating-point or complex type.
        ``None`` if ``dt`` is not an instance of any of the four weak types.
    """
    # Boolean needs no device query: its default dtype is fixed.
    if isinstance(dt, WeakBooleanType):
        return dpt.bool

    # Pick the device query matching the weak kind; the order of the
    # isinstance checks is kept in case the marker classes are related.
    if isinstance(dt, WeakIntegralType):
        default_type_query = ti.default_device_int_type
    elif isinstance(dt, WeakFloatingType):
        default_type_query = ti.default_device_fp_type
    elif isinstance(dt, WeakComplexType):
        default_type_query = ti.default_device_complex_type
    else:
        # Not a recognized weak type.
        return None

    return dpt.dtype(default_type_query(dev))
53+
54+
def _weak_operand_dtype_over_strong(weak_dt, strong_dt, dev):
    """Concrete dtype for a weak operand whose kind outranks a strong dtype.

    Only meant to be called once the caller has established that the weak
    operand's kind number is greater than that of ``strong_dt``.
    """
    if isinstance(weak_dt, WeakIntegralType):
        return dpt.dtype(ti.default_device_int_type(dev))
    if isinstance(weak_dt, WeakComplexType):
        # Half/single precision reals pair with single precision complex.
        if strong_dt is dpt.float16 or strong_dt is dpt.float32:
            return dpt.complex64
        return _to_device_supported_dtype(dpt.complex128, dev)
    # Any remaining weak kind is resolved through float64, adjusted to
    # what the device supports.
    return _to_device_supported_dtype(dpt.float64, dev)


def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):
    """Resolve a pair of possibly-weak data types per NEP-0050.

    Args:
        o1_dtype: Data type (weak or strong) of the first operand.
        o2_dtype: Data type (weak or strong) of the second operand.
        dev: Device used to look up default and supported data types.

    Returns:
        A 2-tuple ``(dtype1, dtype2)``:

        * both weak: each is replaced by its device default dtype;
        * neither weak: both are returned unchanged;
        * exactly one weak: the strong dtype is returned unchanged, and the
          weak one becomes either the strong dtype itself (when its kind
          does not outrank the strong kind) or a concrete dtype of the
          weak operand's own kind.
    """
    first_is_weak = _is_weak_dtype(o1_dtype)
    second_is_weak = _is_weak_dtype(o2_dtype)

    if first_is_weak and second_is_weak:
        resolved_first = _default_dtype_from_weak_type(o1_dtype, dev)
        resolved_second = _default_dtype_from_weak_type(o2_dtype, dev)
        return resolved_first, resolved_second

    if not first_is_weak and not second_is_weak:
        return o1_dtype, o2_dtype

    if first_is_weak:
        weak_rank = _weak_type_num_kind(o1_dtype)
        strong_rank = _strong_dtype_num_kind(o2_dtype)
        if weak_rank > strong_rank:
            promoted = _weak_operand_dtype_over_strong(o1_dtype, o2_dtype, dev)
            return promoted, o2_dtype
        # The weak operand simply adopts the strong dtype.
        return o2_dtype, o2_dtype

    # Only the second operand is weak.
    strong_rank = _strong_dtype_num_kind(o1_dtype)
    weak_rank = _weak_type_num_kind(o2_dtype)
    if weak_rank > strong_rank:
        promoted = _weak_operand_dtype_over_strong(o2_dtype, o1_dtype, dev)
        return o1_dtype, promoted
    return o1_dtype, o1_dtype
2595
2696
2797def _where_result_type (dt1 , dt2 , dev ):
@@ -51,16 +121,17 @@ def where(condition, x1, x2, /, *, order="K", out=None):
51121 and otherwise yields from ``x2``.
52122 Must be compatible with ``x1`` and ``x2`` according
53123 to broadcasting rules.
54- x1 (usm_ndarray): Array from which values are chosen when
55- ``condition`` is ``True``.
124+ x1 (Union[ usm_ndarray, bool, int, float, complex]):
125+ Array from which values are chosen when ``condition`` is ``True``.
56126 Must be compatible with ``condition`` and ``x2`` according
57127 to broadcasting rules.
58- x2 (usm_ndarray): Array from which values are chosen when
59- ``condition`` is not ``True``.
128+ x2 (Union[usm_ndarray, bool, int, float, complex]):
129+ Array from which values are chosen when ``condition`` is not
130+ ``True``.
60131 Must be compatible with ``condition`` and ``x1`` according
61132 to broadcasting rules.
62133 order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional):
63- Memory layout of the new output arra ,
134+ Memory layout of the new output array ,
64135 if parameter ``out`` is ``None``.
65136 Default: ``"K"``.
66137 out (Optional[usm_ndarray]):
@@ -81,36 +152,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
81152 raise TypeError (
82153 "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (condition )} "
83154 )
84- if not isinstance (x1 , dpt .usm_ndarray ):
85- raise TypeError (
86- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x1 )} "
155+ if order not in ["K" , "C" , "F" , "A" ]:
156+ order = "K"
157+ q1 , condition_usm_type = condition .sycl_queue , condition .usm_type
158+ q2 , x1_usm_type = _get_queue_usm_type (x1 )
159+ q3 , x2_usm_type = _get_queue_usm_type (x2 )
160+ if q2 is None and q3 is None :
161+ exec_q = q1
162+ out_usm_type = condition_usm_type
163+ elif q3 is None :
164+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
165+ if exec_q is None :
166+ raise ExecutionPlacementError (
167+ "Execution placement can not be unambiguously inferred "
168+ "from input arguments."
169+ )
170+ out_usm_type = dpctl .utils .get_coerced_usm_type (
171+ (
172+ condition_usm_type ,
173+ x1_usm_type ,
174+ )
87175 )
88- if not isinstance (x2 , dpt .usm_ndarray ):
176+ elif q2 is None :
177+ exec_q = dpctl .utils .get_execution_queue ((q1 , q3 ))
178+ if exec_q is None :
179+ raise ExecutionPlacementError (
180+ "Execution placement can not be unambiguously inferred "
181+ "from input arguments."
182+ )
183+ out_usm_type = dpctl .utils .get_coerced_usm_type (
184+ (
185+ condition_usm_type ,
186+ x2_usm_type ,
187+ )
188+ )
189+ else :
190+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 , q3 ))
191+ if exec_q is None :
192+ raise ExecutionPlacementError (
193+ "Execution placement can not be unambiguously inferred "
194+ "from input arguments."
195+ )
196+ out_usm_type = dpctl .utils .get_coerced_usm_type (
197+ (
198+ condition_usm_type ,
199+ x1_usm_type ,
200+ x2_usm_type ,
201+ )
202+ )
203+ dpctl .utils .validate_usm_type (out_usm_type , allow_none = False )
204+ condition_shape = condition .shape
205+ x1_shape = _get_shape (x1 )
206+ x2_shape = _get_shape (x2 )
207+ if not all (
208+ isinstance (s , (tuple , list ))
209+ for s in (
210+ x1_shape ,
211+ x2_shape ,
212+ )
213+ ):
89214 raise TypeError (
90- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x2 )} "
215+ "Shape of arguments can not be inferred. "
216+ "Arguments are expected to be "
217+ "lists, tuples, or both"
91218 )
92- if order not in [ "K" , "C" , "F" , "A" ] :
93- order = "K"
94- exec_q = dpctl . utils . get_execution_queue (
95- (
96- condition . sycl_queue ,
97- x1 . sycl_queue ,
98- x2 . sycl_queue ,
219+ try :
220+ res_shape = _broadcast_shape_impl (
221+ [
222+ condition_shape ,
223+ x1_shape ,
224+ x2_shape ,
225+ ]
99226 )
100- )
101- if exec_q is None :
102- raise dpctl .utils .ExecutionPlacementError
103- out_usm_type = dpctl .utils .get_coerced_usm_type (
104- (
105- condition .usm_type ,
106- x1 .usm_type ,
107- x2 .usm_type ,
227+ except ValueError :
228+ raise ValueError (
229+ "operands could not be broadcast together with shapes "
230+ f"{ condition_shape } , { x1_shape } , and { x2_shape } "
108231 )
109- )
110-
111- x1_dtype = x1 .dtype
112- x2_dtype = x2 .dtype
113- out_dtype = _where_result_type (x1_dtype , x2_dtype , exec_q .sycl_device )
232+ sycl_dev = exec_q .sycl_device
233+ x1_dtype = _get_dtype (x1 , sycl_dev )
234+ x2_dtype = _get_dtype (x2 , sycl_dev )
235+ if not all (_validate_dtype (o ) for o in (x1_dtype , x2_dtype )):
236+ raise ValueError ("Operands have unsupported data types" )
237+ x1_dtype , x2_dtype = _resolve_two_weak_types (x1_dtype , x2_dtype , sycl_dev )
238+ out_dtype = _where_result_type (x1_dtype , x2_dtype , sycl_dev )
114239 if out_dtype is None :
115240 raise TypeError (
116241 "function 'where' does not support input "
@@ -119,8 +244,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
119244 "to any supported types according to the casting rule ''safe''."
120245 )
121246
122- res_shape = _broadcast_shapes (condition , x1 , x2 )
123-
124247 orig_out = out
125248 if out is not None :
126249 if not isinstance (out , dpt .usm_ndarray ):
@@ -149,16 +272,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
149272 "Input and output allocation queues are not compatible"
150273 )
151274
152- if ti ._array_overlap (condition , out ):
153- if not ti ._same_logical_tensors (condition , out ):
154- out = dpt .empty_like (out )
275+ if ti ._array_overlap (condition , out ) and not ti ._same_logical_tensors (
276+ condition , out
277+ ):
278+ out = dpt .empty_like (out )
155279
156- if ti ._array_overlap (x1 , out ):
157- if not ti ._same_logical_tensors (x1 , out ):
280+ if isinstance (x1 , dpt .usm_ndarray ):
281+ if (
282+ ti ._array_overlap (x1 , out )
283+ and not ti ._same_logical_tensors (x1 , out )
284+ and x1_dtype == out_dtype
285+ ):
158286 out = dpt .empty_like (out )
159287
160- if ti ._array_overlap (x2 , out ):
161- if not ti ._same_logical_tensors (x2 , out ):
288+ if isinstance (x2 , dpt .usm_ndarray ):
289+ if (
290+ ti ._array_overlap (x2 , out )
291+ and not ti ._same_logical_tensors (x2 , out )
292+ and x2_dtype == out_dtype
293+ ):
162294 out = dpt .empty_like (out )
163295
164296 if order == "A" :
@@ -174,6 +306,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
174306 )
175307 else "C"
176308 )
309+ if not isinstance (x1 , dpt .usm_ndarray ):
310+ x1 = dpt .asarray (x1 , dtype = x1_dtype , sycl_queue = exec_q )
311+ if not isinstance (x2 , dpt .usm_ndarray ):
312+ x2 = dpt .asarray (x2 , dtype = x2_dtype , sycl_queue = exec_q )
177313
178314 if condition .size == 0 :
179315 if out is not None :
@@ -236,9 +372,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
236372 sycl_queue = exec_q ,
237373 )
238374
239- condition = dpt .broadcast_to (condition , res_shape )
240- x1 = dpt .broadcast_to (x1 , res_shape )
241- x2 = dpt .broadcast_to (x2 , res_shape )
375+ if condition_shape != res_shape :
376+ condition = dpt .broadcast_to (condition , res_shape )
377+ if x1_shape != res_shape :
378+ x1 = dpt .broadcast_to (x1 , res_shape )
379+ if x2_shape != res_shape :
380+ x2 = dpt .broadcast_to (x2 , res_shape )
242381
243382 dep_evs = _manager .submitted_events
244383 hev , where_ev = ti ._where (
0 commit comments