@@ -147,6 +147,51 @@ def _build_choices_list(choices):
147147 return choices , queues , usm_types
148148
149149
def _choose_run(inds, chcs, q, usm_type, out=None, mode=0):
    """
    Submit the ``_choose`` kernel and return the array it writes into.

    Argument validation, broadcasting and type coercion of `inds` and
    `chcs` are assumed to have been done by the caller; only `out` is
    validated here.

    Parameters
    ----------
    inds
        Index array. Its shape is the shape required of the result.
    chcs
        Sequence of choice arrays. The dtype of the first one is the dtype
        required of the result.
    q
        Queue used for allocating the result, for looking up the
        sequential order manager and for submitting the kernel.
    usm_type
        USM type used when a fresh result array is allocated.
    out
        Optional destination array. When given, it must be writable, have
        the shape of `inds`, the dtype of ``chcs[0]`` and a queue
        compatible with `q`.
    mode
        Passed through unchanged to ``indexing_ext._choose``.

    Returns
    -------
    The array handed to the kernel: a newly allocated array when `out` is
    ``None``, a temporary ``empty_like`` buffer when `out` overlaps `inds`
    or any element of `chcs`, and otherwise the array returned by
    ``dpnp.get_usm_ndarray(out)``.

    Raises
    ------
    ValueError
        If `out` is read-only, or its shape or dtype does not match.
    dpu.ExecutionPlacementError
        If no common execution queue exists for `q` and `out`.

    Notes
    -----
    When the temporary buffer is used, this function does not copy it back
    into `out`; that step is left to the caller.
    """
    if out is None:
        res = dpt.empty(
            inds.shape, dtype=chcs[0].dtype, usm_type=usm_type, sycl_queue=q
        )
    else:
        dpnp.check_supported_arrays_type(out)
        res = dpnp.get_usm_ndarray(out)

        if not res.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if res.shape != inds.shape:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {inds.shape} , got {res.shape} "
            )

        if chcs[0].dtype != res.dtype:
            raise ValueError(
                f"Output array of type {chcs[0].dtype} is needed, "
                f"got {res.dtype} "
            )

        if dpu.get_execution_queue((q, res.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        shares_memory = ti._array_overlap(inds, res) or any(
            ti._array_overlap(res, chc) for chc in chcs
        )
        if shares_memory:
            # Write into a temporary buffer so the kernel never reads from
            # memory it is simultaneously writing to.
            res = dpt.empty_like(res)

    order_mgr = dpu.SequentialOrderManager[q]
    pending_evs = order_mgr.submitted_events

    host_ev, kernel_ev = indexing_ext._choose(
        inds, chcs, res, mode, q, pending_evs
    )
    order_mgr.add_event_pair(host_ev, kernel_ev)

    return res
193+
194+
150195def choose (x , choices , out = None , mode = "wrap" ):
151196 """
152197 Construct an array from an index array and a set of arrays to choose from.
@@ -218,59 +263,10 @@ def choose(x, choices, out=None, mode="wrap"):
218263 arrs_broadcast = dpt .broadcast_arrays (inds , * choices )
219264 inds = arrs_broadcast [0 ]
220265 choices = tuple (arrs_broadcast [1 :])
221- res_sh = inds .shape
222266
223- orig_out = out
224- if out is not None :
225- dpnp .check_supported_arrays_type (out )
226- out = dpnp .get_usm_ndarray (out )
227-
228- if not out .flags .writable :
229- raise ValueError ("provided `out` array is read-only" )
230-
231- if out .shape != res_sh :
232- raise ValueError (
233- "The shape of input and output arrays are inconsistent. "
234- f"Expected output shape is { res_sh } , got { out .shape } "
235- )
236-
237- if res_dt != out .dtype :
238- raise ValueError (
239- f"Output array of type { res_dt } is needed, " f"got { out .dtype } "
240- )
241-
242- if dpu .get_execution_queue ((x .sycl_queue , out .sycl_queue )) is None :
243- raise dpu .ExecutionPlacementError (
244- "Input and output allocation queues are not compatible"
245- )
246-
247- if ti ._array_overlap (x , out ) or any (
248- ti ._array_overlap (out , chc ) for chc in choices
249- ):
250- # Allocate a temporary buffer to avoid memory overlapping.
251- out = dpt .empty_like (out )
252- else :
253- out = dpt .empty (
254- res_sh , dtype = res_dt , usm_type = res_usm_type , sycl_queue = exec_q
255- )
267+ res = _choose_run (inds , choices , exec_q , res_usm_type , out = out , mode = mode )
256268
257- _manager = dpu .SequentialOrderManager [exec_q ]
258- dep_evs = _manager .submitted_events
259-
260- h_ev , choose_ev = indexing_ext ._choose (
261- inds , choices , out , mode , exec_q , dep_evs
262- )
263- _manager .add_event_pair (h_ev , choose_ev )
264-
265- if not (orig_out is None or orig_out is out ):
266- # Copy the out data from temporary buffer to original memory
267- ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
268- src = out , dst = orig_out , sycl_queue = exec_q , depends = [choose_ev ]
269- )
270- _manager .add_event_pair (ht_copy_ev , cpy_ev )
271- out = orig_out
272-
273- return dpnp .get_result_array (out )
269+ return dpnp .get_result_array (res , out = out )
274270
275271
276272def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
0 commit comments