Skip to content

Commit a73e22b

Browse files
committed
Factor out keyword validation and kernel run out of choose
1 parent 7d0138e commit a73e22b

File tree

1 file changed

+47
-51
lines changed

1 file changed

+47
-51
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,51 @@ def _build_choices_list(choices):
147147
return choices, queues, usm_types
148148

149149

150+
def _choose_run(inds, chcs, q, usm_type, out=None, mode=0):
    """
    Submit the ``choose`` kernel for already-prepared arguments.

    Argument validation, broadcasting and type coercion of `inds` and
    `chcs` are assumed to have been done by the caller; only `out` is
    validated here.

    Parameters
    ----------
    inds : usm_ndarray
        Index array; its shape defines the shape of the result.
    chcs : sequence of usm_ndarray
        Choice arrays; the data type of the first one defines the data type
        of the result.
    q : SyclQueue
        Queue the kernel is submitted to and on which a fresh result array
        is allocated.
    usm_type : str
        USM allocation type used when a fresh result array is allocated.
    out : {None, dpnp.ndarray, usm_ndarray}, optional
        Optional destination array.
    mode : int, optional
        Value forwarded unchanged to ``indexing_ext._choose``.

    Returns
    -------
    usm_ndarray
        The array the kernel writes into. This is a newly allocated array
        when `out` is ``None``, and a temporary buffer when `out` overlaps
        `inds` or any of `chcs`; no copy back into `out` is made here, that
        is left to the caller.

    Raises
    ------
    ValueError
        If `out` is read-only, or its shape differs from ``inds.shape``, or
        its data type differs from ``chcs[0].dtype``.
    ExecutionPlacementError
        If `q` and the queue of `out` are not compatible.

    """
    if out is None:
        dst = dpt.empty(
            inds.shape, dtype=chcs[0].dtype, usm_type=usm_type, sycl_queue=q
        )
    else:
        dpnp.check_supported_arrays_type(out)
        dst = dpnp.get_usm_ndarray(out)

        if not dst.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if dst.shape != inds.shape:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {inds.shape}, got {dst.shape}"
            )

        if chcs[0].dtype != dst.dtype:
            raise ValueError(
                f"Output array of type {chcs[0].dtype} is needed, "
                f"got {dst.dtype}"
            )

        if dpu.get_execution_queue((q, dst.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        # The kernel must not write into memory it is still reading from:
        # check the indices first, then every choice array in turn.
        overlaps = ti._array_overlap(inds, dst)
        if not overlaps:
            overlaps = any(ti._array_overlap(dst, chc) for chc in chcs)
        if overlaps:
            # Allocate a temporary buffer to avoid memory overlapping.
            dst = dpt.empty_like(dst)

    # Order the kernel after everything already submitted through the
    # manager for this queue, then register its events with the manager.
    order_manager = dpu.SequentialOrderManager[q]
    pending_evs = order_manager.submitted_events

    host_ev, kernel_ev = indexing_ext._choose(
        inds, chcs, dst, mode, q, pending_evs
    )
    order_manager.add_event_pair(host_ev, kernel_ev)

    return dst
193+
194+
150195
def choose(x, choices, out=None, mode="wrap"):
151196
"""
152197
Construct an array from an index array and a set of arrays to choose from.
@@ -218,59 +263,10 @@ def choose(x, choices, out=None, mode="wrap"):
218263
arrs_broadcast = dpt.broadcast_arrays(inds, *choices)
219264
inds = arrs_broadcast[0]
220265
choices = tuple(arrs_broadcast[1:])
221-
res_sh = inds.shape
222266

223-
orig_out = out
224-
if out is not None:
225-
dpnp.check_supported_arrays_type(out)
226-
out = dpnp.get_usm_ndarray(out)
227-
228-
if not out.flags.writable:
229-
raise ValueError("provided `out` array is read-only")
230-
231-
if out.shape != res_sh:
232-
raise ValueError(
233-
"The shape of input and output arrays are inconsistent. "
234-
f"Expected output shape is {res_sh}, got {out.shape}"
235-
)
236-
237-
if res_dt != out.dtype:
238-
raise ValueError(
239-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
240-
)
241-
242-
if dpu.get_execution_queue((x.sycl_queue, out.sycl_queue)) is None:
243-
raise dpu.ExecutionPlacementError(
244-
"Input and output allocation queues are not compatible"
245-
)
246-
247-
if ti._array_overlap(x, out) or any(
248-
ti._array_overlap(out, chc) for chc in choices
249-
):
250-
# Allocate a temporary buffer to avoid memory overlapping.
251-
out = dpt.empty_like(out)
252-
else:
253-
out = dpt.empty(
254-
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
255-
)
267+
res = _choose_run(inds, choices, exec_q, res_usm_type, out=out, mode=mode)
256268

257-
_manager = dpu.SequentialOrderManager[exec_q]
258-
dep_evs = _manager.submitted_events
259-
260-
h_ev, choose_ev = indexing_ext._choose(
261-
inds, choices, out, mode, exec_q, dep_evs
262-
)
263-
_manager.add_event_pair(h_ev, choose_ev)
264-
265-
if not (orig_out is None or orig_out is out):
266-
# Copy the out data from temporary buffer to original memory
267-
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
268-
src=out, dst=orig_out, sycl_queue=exec_q, depends=[choose_ev]
269-
)
270-
_manager.add_event_pair(ht_copy_ev, cpy_ev)
271-
out = orig_out
272-
273-
return dpnp.get_result_array(out)
269+
return dpnp.get_result_array(res, out=out)
274270

275271

276272
def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):

0 commit comments

Comments
 (0)