4646import dpctl .utils as dpu
4747import numpy
4848from dpctl .tensor ._copy_utils import _nonzero_impl
49+ from dpctl .tensor ._indexing_functions import _get_indexing_mode
4950from dpctl .tensor ._numpy_helper import normalize_axis_index
5051
5152import dpnp
@@ -160,14 +161,13 @@ def choose(x1, choices, out=None, mode="raise"):
160161 return call_origin (numpy .choose , x1 , choices , out , mode )
161162
162163
163- def _take_1d_index (x , inds , axis , q , usm_type , out = None ):
164+ def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
164165 # arg validation assumed done by caller
165166 x_sh = x .shape
166- ind0 = inds [0 ]
167167 axis_end = axis + 1
168- if 0 in x_sh [axis :axis_end ] and ind0 .size != 0 :
168+ if 0 in x_sh [axis :axis_end ] and inds .size != 0 :
169169 raise IndexError ("cannot take non-empty indices from an empty axis" )
170- res_sh = x_sh [:axis ] + ind0 .shape + x_sh [axis_end :]
170+ res_sh = x_sh [:axis ] + inds .shape + x_sh [axis_end :]
171171
172172 orig_out = None
173173 if out is not None :
@@ -201,13 +201,12 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
201201 _manager = dpu .SequentialOrderManager [q ]
202202 dep_evs = _manager .submitted_events
203203
204- # always use wrap mode here
205204 h_ev , take_ev = ti ._take (
206205 src = x ,
207- ind = inds ,
206+ ind = ( inds ,) ,
208207 dst = out ,
209208 axis_start = axis ,
210- mode = 0 ,
209+ mode = mode ,
211210 sycl_queue = q ,
212211 depends = dep_evs ,
213212 )
@@ -318,7 +317,8 @@ def compress(condition, a, axis=None, out=None):
318317 inds = _nonzero_impl (cond_ary )
319318
320319 return dpnp .get_result_array (
321- _take_1d_index (a_ary , inds , axis , exec_q , res_usm_type , out ), out = out
320+ _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out ),
321+ out = out ,
322322 )
323323
324324
@@ -1902,8 +1902,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19021902
19031903 """
19041904
1905- if mode not in ( "wrap" , "clip" ):
1906- raise ValueError ( f"` mode` must be 'wrap' or 'clip', but got ` { mode } `." )
1905+ # sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1906+ mode = _get_indexing_mode ( mode )
19071907
19081908 usm_a = dpnp .get_usm_ndarray (a )
19091909 if not dpnp .is_supported_array_type (indices ):
@@ -1913,34 +1913,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19131913 else :
19141914 usm_ind = dpnp .get_usm_ndarray (indices )
19151915
1916+ res_usm_type , exec_q = get_usm_allocations ([usm_a , usm_ind ])
1917+
19161918 a_ndim = a .ndim
19171919 if axis is None :
1918- res_shape = usm_ind .shape
1919-
19201920 if a_ndim > 1 :
1921- # dpt.take requires flattened input array
1921+ # flatten input array
19221922 usm_a = dpt .reshape (usm_a , - 1 )
1923+ axis = 0
19231924 elif a_ndim == 0 :
19241925 axis = normalize_axis_index (operator .index (axis ), 1 )
1925- res_shape = usm_ind .shape
19261926 else :
19271927 axis = normalize_axis_index (operator .index (axis ), a_ndim )
1928- a_sh = a .shape
1929- res_shape = a_sh [:axis ] + usm_ind .shape + a_sh [axis + 1 :]
1930-
1931- if usm_ind .ndim != 1 :
1932- # dpt.take supports only 1-D array of indices
1933- usm_ind = dpt .reshape (usm_ind , - 1 )
19341928
19351929 if not dpnp .issubdtype (usm_ind .dtype , dpnp .integer ):
19361930 # dpt.take supports only integer dtype for array of indices
19371931 usm_ind = dpt .astype (usm_ind , dpnp .intp , copy = False , casting = "safe" )
19381932
1939- usm_res = dpt .take (usm_a , usm_ind , axis = axis , mode = mode )
1933+ usm_res = _take_index (
1934+ usm_a , usm_ind , axis , exec_q , res_usm_type , out = out , mode = mode
1935+ )
19401936
1941- # need to reshape the result if shape of indices array was changed
1942- result = dpnp .reshape (usm_res , res_shape )
1943- return dpnp .get_result_array (result , out )
1937+ return dpnp .get_result_array (usm_res , out = out )
19441938
19451939
19461940def take_along_axis (a , indices , axis , mode = "wrap" ):
0 commit comments