Skip to content

Commit 0f44d42

Browse files
committed
Factor out keyword validation and kernel run out of choose
1 parent 519e0d4 commit 0f44d42

File tree

1 file changed

+47
-51
lines changed

1 file changed

+47
-51
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,51 @@ def _build_choices_list(choices):
146146
return choices, queues, usm_types
147147

148148

149+
def _choose_run(inds, chcs, q, usm_type, out=None, mode=0):
150+
# arg validation, broadcasting, type coercion assumed done by caller
151+
if out is not None:
152+
dpnp.check_supported_arrays_type(out)
153+
out = dpnp.get_usm_ndarray(out)
154+
155+
if not out.flags.writable:
156+
raise ValueError("provided `out` array is read-only")
157+
158+
if out.shape != inds.shape:
159+
raise ValueError(
160+
"The shape of input and output arrays are inconsistent. "
161+
f"Expected output shape is {inds.shape}, got {out.shape}"
162+
)
163+
164+
if chcs[0].dtype != out.dtype:
165+
raise ValueError(
166+
f"Output array of type {chcs[0].dtype} is needed, "
167+
f"got {out.dtype}"
168+
)
169+
170+
if dpu.get_execution_queue((q, out.sycl_queue)) is None:
171+
raise dpu.ExecutionPlacementError(
172+
"Input and output allocation queues are not compatible"
173+
)
174+
175+
if ti._array_overlap(inds, out) or any(
176+
ti._array_overlap(out, chc) for chc in chcs
177+
):
178+
# Allocate a temporary buffer to avoid memory overlapping.
179+
out = dpt.empty_like(out)
180+
else:
181+
out = dpt.empty(
182+
inds.shape, dtype=chcs[0].dtype, usm_type=usm_type, sycl_queue=q
183+
)
184+
185+
_manager = dpu.SequentialOrderManager[q]
186+
dep_evs = _manager.submitted_events
187+
188+
h_ev, choose_ev = indexing_ext._choose(inds, chcs, out, mode, q, dep_evs)
189+
_manager.add_event_pair(h_ev, choose_ev)
190+
191+
return out
192+
193+
149194
def choose(x, choices, out=None, mode="wrap"):
150195
"""
151196
Construct an array from an index array and a set of arrays to choose from.
@@ -217,59 +262,10 @@ def choose(x, choices, out=None, mode="wrap"):
217262
arrs_broadcast = dpt.broadcast_arrays(inds, *choices)
218263
inds = arrs_broadcast[0]
219264
choices = tuple(arrs_broadcast[1:])
220-
res_sh = inds.shape
221265

222-
orig_out = out
223-
if out is not None:
224-
dpnp.check_supported_arrays_type(out)
225-
out = dpnp.get_usm_ndarray(out)
226-
227-
if not out.flags.writable:
228-
raise ValueError("provided `out` array is read-only")
229-
230-
if out.shape != res_sh:
231-
raise ValueError(
232-
"The shape of input and output arrays are inconsistent. "
233-
f"Expected output shape is {res_sh}, got {out.shape}"
234-
)
235-
236-
if res_dt != out.dtype:
237-
raise ValueError(
238-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
239-
)
240-
241-
if dpu.get_execution_queue((x.sycl_queue, out.sycl_queue)) is None:
242-
raise dpu.ExecutionPlacementError(
243-
"Input and output allocation queues are not compatible"
244-
)
245-
246-
if ti._array_overlap(x, out) or any(
247-
ti._array_overlap(out, chc) for chc in choices
248-
):
249-
# Allocate a temporary buffer to avoid memory overlapping.
250-
out = dpt.empty_like(out)
251-
else:
252-
out = dpt.empty(
253-
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
254-
)
266+
res = _choose_run(inds, choices, exec_q, res_usm_type, out=out, mode=mode)
255267

256-
_manager = dpu.SequentialOrderManager[exec_q]
257-
dep_evs = _manager.submitted_events
258-
259-
h_ev, choose_ev = indexing_ext._choose(
260-
inds, choices, out, mode, exec_q, dep_evs
261-
)
262-
_manager.add_event_pair(h_ev, choose_ev)
263-
264-
if not (orig_out is None or orig_out is out):
265-
# Copy the out data from temporary buffer to original memory
266-
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
267-
src=out, dst=orig_out, sycl_queue=exec_q, depends=[choose_ev]
268-
)
269-
_manager.add_event_pair(ht_copy_ev, cpy_ev)
270-
out = orig_out
271-
272-
return dpnp.get_result_array(out)
268+
return dpnp.get_result_array(res, out=out)
273269

274270

275271
def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):

0 commit comments

Comments
 (0)