4646import dpctl .utils as dpu
4747import numpy
4848from dpctl .tensor ._copy_utils import _nonzero_impl
49+ from dpctl .tensor ._indexing_functions import _get_indexing_mode
4950from dpctl .tensor ._numpy_helper import normalize_axis_index
5051
5152import dpnp
@@ -161,14 +162,13 @@ def choose(x1, choices, out=None, mode="raise"):
161162 return call_origin (numpy .choose , x1 , choices , out , mode )
162163
163164
164- def _take_1d_index (x , inds , axis , q , usm_type , out = None ):
165+ def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
165166 # arg validation assumed done by caller
166167 x_sh = x .shape
167- ind0 = inds [0 ]
168168 axis_end = axis + 1
169- if 0 in x_sh [axis :axis_end ] and ind0 .size != 0 :
169+ if 0 in x_sh [axis :axis_end ] and inds .size != 0 :
170170 raise IndexError ("cannot take non-empty indices from an empty axis" )
171- res_sh = x_sh [:axis ] + ind0 .shape + x_sh [axis_end :]
171+ res_sh = x_sh [:axis ] + inds .shape + x_sh [axis_end :]
172172
173173 orig_out = None
174174 if out is not None :
@@ -202,13 +202,12 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
202202 _manager = dpu .SequentialOrderManager [q ]
203203 dep_evs = _manager .submitted_events
204204
205- # always use wrap mode here
206205 h_ev , take_ev = ti ._take (
207206 src = x ,
208- ind = inds ,
207+ ind = ( inds ,) ,
209208 dst = out ,
210209 axis_start = axis ,
211- mode = 0 ,
210+ mode = mode ,
212211 sycl_queue = q ,
213212 depends = dep_evs ,
214213 )
@@ -319,7 +318,8 @@ def compress(condition, a, axis=None, out=None):
319318 inds = _nonzero_impl (cond_ary )
320319
321320 return dpnp .get_result_array (
322- _take_1d_index (a_ary , inds , axis , exec_q , res_usm_type , out ), out = out
321+ _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out ),
322+ out = out ,
323323 )
324324
325325
@@ -1974,8 +1974,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19741974
19751975 """
19761976
1977- if mode not in ( "wrap" , "clip" ):
1978- raise ValueError ( f"` mode` must be 'wrap' or 'clip', but got ` { mode } `." )
1977+ # sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1978+ mode = _get_indexing_mode ( mode )
19791979
19801980 usm_a = dpnp .get_usm_ndarray (a )
19811981 if not dpnp .is_supported_array_type (indices ):
@@ -1985,34 +1985,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19851985 else :
19861986 usm_ind = dpnp .get_usm_ndarray (indices )
19871987
1988+ res_usm_type , exec_q = get_usm_allocations ([usm_a , usm_ind ])
1989+
19881990 a_ndim = a .ndim
19891991 if axis is None :
1990- res_shape = usm_ind .shape
1991-
19921992 if a_ndim > 1 :
1993- # dpt.take requires flattened input array
1993+ # flatten input array
19941994 usm_a = dpt .reshape (usm_a , - 1 )
1995+ axis = 0
19951996 elif a_ndim == 0 :
19961997 axis = normalize_axis_index (operator .index (axis ), 1 )
1997- res_shape = usm_ind .shape
19981998 else :
19991999 axis = normalize_axis_index (operator .index (axis ), a_ndim )
2000- a_sh = a .shape
2001- res_shape = a_sh [:axis ] + usm_ind .shape + a_sh [axis + 1 :]
2002-
2003- if usm_ind .ndim != 1 :
2004- # dpt.take supports only 1-D array of indices
2005- usm_ind = dpt .reshape (usm_ind , - 1 )
20062000
20072001 if not dpnp .issubdtype (usm_ind .dtype , dpnp .integer ):
20082002 # dpt.take supports only integer dtype for array of indices
20092003 usm_ind = dpt .astype (usm_ind , dpnp .intp , copy = False , casting = "safe" )
20102004
2011- usm_res = dpt .take (usm_a , usm_ind , axis = axis , mode = mode )
2005+ usm_res = _take_index (
2006+ usm_a , usm_ind , axis , exec_q , res_usm_type , out = out , mode = mode
2007+ )
20122008
2013- # need to reshape the result if shape of indices array was changed
2014- result = dpnp .reshape (usm_res , res_shape )
2015- return dpnp .get_result_array (result , out )
2009+ return dpnp .get_result_array (usm_res , out = out )
20162010
20172011
20182012def take_along_axis (a , indices , axis , mode = "wrap" ):
0 commit comments