Skip to content

Commit e53b84e

Browse files
committed
Re-use _take_index for dpnp.take
Should slightly improve efficiency by escaping an additional copy where `out` is not `None` and flattening of indices
1 parent 49bfa0e commit e53b84e

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import dpctl.utils as dpu
4747
import numpy
4848
from dpctl.tensor._copy_utils import _nonzero_impl
49+
from dpctl.tensor._indexing_functions import _get_indexing_mode
4950
from dpctl.tensor._numpy_helper import normalize_axis_index
5051

5152
import dpnp
@@ -161,14 +162,13 @@ def choose(x1, choices, out=None, mode="raise"):
161162
return call_origin(numpy.choose, x1, choices, out, mode)
162163

163164

164-
def _take_1d_index(x, inds, axis, q, usm_type, out=None):
165+
def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
165166
# arg validation assumed done by caller
166167
x_sh = x.shape
167-
ind0 = inds[0]
168168
axis_end = axis + 1
169-
if 0 in x_sh[axis:axis_end] and ind0.size != 0:
169+
if 0 in x_sh[axis:axis_end] and inds.size != 0:
170170
raise IndexError("cannot take non-empty indices from an empty axis")
171-
res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]
171+
res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:]
172172

173173
orig_out = None
174174
if out is not None:
@@ -202,13 +202,12 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
202202
_manager = dpu.SequentialOrderManager[q]
203203
dep_evs = _manager.submitted_events
204204

205-
# always use wrap mode here
206205
h_ev, take_ev = ti._take(
207206
src=x,
208-
ind=inds,
207+
ind=(inds,),
209208
dst=out,
210209
axis_start=axis,
211-
mode=0,
210+
mode=mode,
212211
sycl_queue=q,
213212
depends=dep_evs,
214213
)
@@ -319,7 +318,8 @@ def compress(condition, a, axis=None, out=None):
319318
inds = _nonzero_impl(cond_ary)
320319

321320
return dpnp.get_result_array(
322-
_take_1d_index(a_ary, inds, axis, exec_q, res_usm_type, out), out=out
321+
_take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out),
322+
out=out,
323323
)
324324

325325

@@ -1974,8 +1974,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19741974
19751975
"""
19761976

1977-
if mode not in ("wrap", "clip"):
1978-
raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.")
1977+
# sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1978+
mode = _get_indexing_mode(mode)
19791979

19801980
usm_a = dpnp.get_usm_ndarray(a)
19811981
if not dpnp.is_supported_array_type(indices):
@@ -1985,34 +1985,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19851985
else:
19861986
usm_ind = dpnp.get_usm_ndarray(indices)
19871987

1988+
res_usm_type, exec_q = get_usm_allocations([usm_a, usm_ind])
1989+
19881990
a_ndim = a.ndim
19891991
if axis is None:
1990-
res_shape = usm_ind.shape
1991-
19921992
if a_ndim > 1:
1993-
# dpt.take requires flattened input array
1993+
# flatten input array
19941994
usm_a = dpt.reshape(usm_a, -1)
1995+
axis = 0
19951996
elif a_ndim == 0:
19961997
axis = normalize_axis_index(operator.index(axis), 1)
1997-
res_shape = usm_ind.shape
19981998
else:
19991999
axis = normalize_axis_index(operator.index(axis), a_ndim)
2000-
a_sh = a.shape
2001-
res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :]
2002-
2003-
if usm_ind.ndim != 1:
2004-
# dpt.take supports only 1-D array of indices
2005-
usm_ind = dpt.reshape(usm_ind, -1)
20062000

20072001
if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
20082002
# dpt.take supports only integer dtype for array of indices
20092003
usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe")
20102004

2011-
usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode)
2005+
usm_res = _take_index(
2006+
usm_a, usm_ind, axis, exec_q, res_usm_type, out=out, mode=mode
2007+
)
20122008

2013-
# need to reshape the result if shape of indices array was changed
2014-
result = dpnp.reshape(usm_res, res_shape)
2015-
return dpnp.get_result_array(result, out)
2009+
return dpnp.get_result_array(usm_res, out=out)
20162010

20172011

20182012
def take_along_axis(a, indices, axis, mode="wrap"):

0 commit comments

Comments
 (0)