@@ -146,6 +146,51 @@ def _build_choices_list(choices):
146146 return choices , queues , usm_types
147147
148148
149+ def _choose_run (inds , chcs , q , usm_type , out = None , mode = 0 ):
150+ # arg validation, broadcasting, type coercion assumed done by caller
151+ if out is not None :
152+ dpnp .check_supported_arrays_type (out )
153+ out = dpnp .get_usm_ndarray (out )
154+
155+ if not out .flags .writable :
156+ raise ValueError ("provided `out` array is read-only" )
157+
158+ if out .shape != inds .shape :
159+ raise ValueError (
160+ "The shape of input and output arrays are inconsistent. "
161+ f"Expected output shape is { inds .shape } , got { out .shape } "
162+ )
163+
164+ if chcs [0 ].dtype != out .dtype :
165+ raise ValueError (
166+ f"Output array of type { chcs [0 ].dtype } is needed, "
167+ f"got { out .dtype } "
168+ )
169+
170+ if dpu .get_execution_queue ((q , out .sycl_queue )) is None :
171+ raise dpu .ExecutionPlacementError (
172+ "Input and output allocation queues are not compatible"
173+ )
174+
175+ if ti ._array_overlap (inds , out ) or any (
176+ ti ._array_overlap (out , chc ) for chc in chcs
177+ ):
178+ # Allocate a temporary buffer to avoid memory overlapping.
179+ out = dpt .empty_like (out )
180+ else :
181+ out = dpt .empty (
182+ inds .shape , dtype = chcs [0 ].dtype , usm_type = usm_type , sycl_queue = q
183+ )
184+
185+ _manager = dpu .SequentialOrderManager [q ]
186+ dep_evs = _manager .submitted_events
187+
188+ h_ev , choose_ev = indexing_ext ._choose (inds , chcs , out , mode , q , dep_evs )
189+ _manager .add_event_pair (h_ev , choose_ev )
190+
191+ return out
192+
193+
149194def choose (x , choices , out = None , mode = "wrap" ):
150195 """
151196 Construct an array from an index array and a set of arrays to choose from.
@@ -217,59 +262,10 @@ def choose(x, choices, out=None, mode="wrap"):
217262 arrs_broadcast = dpt .broadcast_arrays (inds , * choices )
218263 inds = arrs_broadcast [0 ]
219264 choices = tuple (arrs_broadcast [1 :])
220- res_sh = inds .shape
221265
222- orig_out = out
223- if out is not None :
224- dpnp .check_supported_arrays_type (out )
225- out = dpnp .get_usm_ndarray (out )
226-
227- if not out .flags .writable :
228- raise ValueError ("provided `out` array is read-only" )
229-
230- if out .shape != res_sh :
231- raise ValueError (
232- "The shape of input and output arrays are inconsistent. "
233- f"Expected output shape is { res_sh } , got { out .shape } "
234- )
235-
236- if res_dt != out .dtype :
237- raise ValueError (
238- f"Output array of type { res_dt } is needed, " f"got { out .dtype } "
239- )
240-
241- if dpu .get_execution_queue ((x .sycl_queue , out .sycl_queue )) is None :
242- raise dpu .ExecutionPlacementError (
243- "Input and output allocation queues are not compatible"
244- )
245-
246- if ti ._array_overlap (x , out ) or any (
247- ti ._array_overlap (out , chc ) for chc in choices
248- ):
249- # Allocate a temporary buffer to avoid memory overlapping.
250- out = dpt .empty_like (out )
251- else :
252- out = dpt .empty (
253- res_sh , dtype = res_dt , usm_type = res_usm_type , sycl_queue = exec_q
254- )
266+ res = _choose_run (inds , choices , exec_q , res_usm_type , out = out , mode = mode )
255267
256- _manager = dpu .SequentialOrderManager [exec_q ]
257- dep_evs = _manager .submitted_events
258-
259- h_ev , choose_ev = indexing_ext ._choose (
260- inds , choices , out , mode , exec_q , dep_evs
261- )
262- _manager .add_event_pair (h_ev , choose_ev )
263-
264- if not (orig_out is None or orig_out is out ):
265- # Copy the out data from temporary buffer to original memory
266- ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
267- src = out , dst = orig_out , sycl_queue = exec_q , depends = [choose_ev ]
268- )
269- _manager .add_event_pair (ht_copy_ev , cpy_ev )
270- out = orig_out
271-
272- return dpnp .get_result_array (out )
268+ return dpnp .get_result_array (res , out = out )
273269
274270
275271def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
0 commit comments