3737
3838"""
3939
40+ # pylint: disable=protected-access
41+
4042import operator
4143
4244import dpctl .tensor as dpt
@@ -159,6 +161,71 @@ def choose(x1, choices, out=None, mode="raise"):
159161 return call_origin (numpy .choose , x1 , choices , out , mode )
160162
161163
164+ def _take_1d_index (x , inds , axis , q , usm_type , out = None ):
165+ # arg validation assumed done by caller
166+ x_sh = x .shape
167+ ind0 = inds [0 ]
168+ axis_end = axis + 1
169+ if 0 in x_sh [axis :axis_end ] and ind0 .size != 0 :
170+ raise IndexError ("cannot take non-empty indices from an empty axis" )
171+ res_sh = x_sh [:axis ] + ind0 .shape + x_sh [axis_end :]
172+
173+ orig_out = out
174+ if out is not None :
175+ dpnp .check_supported_arrays_type (out )
176+ out = dpnp .get_usm_ndarray (out )
177+
178+ if not out .flags .writable :
179+ raise ValueError ("provided `out` array is read-only" )
180+
181+ if out .shape != res_sh :
182+ raise ValueError (
183+ "The shape of input and output arrays are inconsistent. "
184+ f"Expected output shape is { res_sh } , got { out .shape } "
185+ )
186+
187+ if x .dtype != out .dtype :
188+ raise ValueError (
189+ f"Output array of type { x .dtype } is needed, " f"got { out .dtype } "
190+ )
191+
192+ if dpu .get_execution_queue ((q , out .sycl_queue )) is None :
193+ raise dpu .ExecutionPlacementError (
194+ "Input and output allocation queues are not compatible"
195+ )
196+
197+ if ti ._array_overlap (x , out ):
198+ # Allocate a temporary buffer to avoid memory overlapping.
199+ out = dpt .empty_like (out )
200+ else :
201+ out = dpt .empty (res_sh , dtype = x .dtype , usm_type = usm_type , sycl_queue = q )
202+
203+ _manager = dpu .SequentialOrderManager [q ]
204+ dep_evs = _manager .submitted_events
205+
206+ # always use wrap mode here
207+ h_ev , take_ev = ti ._take (
208+ src = x ,
209+ ind = inds ,
210+ dst = out ,
211+ axis_start = axis ,
212+ mode = 0 ,
213+ sycl_queue = q ,
214+ depends = dep_evs ,
215+ )
216+ _manager .add_event_pair (h_ev , take_ev )
217+
218+ if not (orig_out is None or orig_out is out ):
219+ # Copy the out data from temporary buffer to original memory
220+ ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
221+ src = out , dst = orig_out , sycl_queue = q , depends = [take_ev ]
222+ )
223+ _manager .add_event_pair (ht_copy_ev , cpy_ev )
224+ out = orig_out
225+
226+ return out
227+
228+
162229def compress (condition , a , axis = None , out = None ):
163230 """
164231 Return selected slices of an array along given axis.
@@ -196,8 +263,7 @@ def compress(condition, a, axis=None, out=None):
196263 if a .ndim != 1 :
197264 a = dpnp .ravel (a )
198265 axis = 0
199- else :
200- axis = normalize_axis_index (operator .index (axis ), a .ndim )
266+ axis = normalize_axis_index (operator .index (axis ), a .ndim )
201267
202268 a_ary = dpnp .get_usm_ndarray (a )
203269 if not dpnp .is_supported_array_type (condition ):
@@ -217,7 +283,7 @@ def compress(condition, a, axis=None, out=None):
217283 usm_types_ = [a_ary .usm_type , cond_ary .usm_type ]
218284 if not cond_ary .ndim == 1 :
219285 raise ValueError (
220- "`condition` must be a 1-D array or un-nested " " sequence"
286+ "`condition` must be a 1-D array or un-nested sequence"
221287 )
222288
223289 res_usm_type = dpu .get_coerced_usm_type (usm_types_ )
@@ -227,74 +293,12 @@ def compress(condition, a, axis=None, out=None):
227293 "arrays must be allocated on the same SYCL queue"
228294 )
229295
230- inds = _nonzero_impl (cond_ary ) # synchronizes
231-
232- res_dt = a_ary .dtype
233- ind0 = inds [0 ]
234- a_sh = a_ary .shape
235- axis_end = axis + 1
236- if 0 in a_sh [axis :axis_end ] and ind0 .size != 0 :
237- raise IndexError ("cannot take non-empty indices from an empty axis" )
238- res_sh = a_sh [:axis ] + ind0 .shape + a_sh [axis_end :]
239-
240- orig_out = out
241- if out is not None :
242- dpnp .check_supported_arrays_type (out )
243- out = dpnp .get_usm_ndarray (out )
244-
245- if not out .flags .writable :
246- raise ValueError ("provided `out` array is read-only" )
247-
248- if out .shape != res_sh :
249- raise ValueError (
250- "The shape of input and output arrays are inconsistent. "
251- f"Expected output shape is { res_sh } , got { out .shape } "
252- )
253-
254- if res_dt != out .dtype :
255- raise ValueError (
256- f"Output array of type { res_dt } is needed, " f"got { out .dtype } "
257- )
258-
259- if dpu .get_execution_queue ((a_ary .sycl_queue , out .sycl_queue )) is None :
260- raise dpu .ExecutionPlacementError (
261- "Input and output allocation queues are not compatible"
262- )
263-
264- if ti ._array_overlap (a_ary , out ):
265- # Allocate a temporary buffer to avoid memory overlapping.
266- out = dpt .empty_like (out )
267- else :
268- out = dpt .empty (
269- res_sh , dtype = res_dt , usm_type = res_usm_type , sycl_queue = exec_q
270- )
271-
272- if out .size == 0 :
273- return out
296+ # _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
297+ inds = _nonzero_impl (cond_ary )
274298
275- _manager = dpu .SequentialOrderManager [exec_q ]
276- dep_evs = _manager .submitted_events
277-
278- h_ev , take_ev = ti ._take (
279- src = a_ary ,
280- ind = inds ,
281- dst = out ,
282- axis_start = axis ,
283- mode = 0 ,
284- sycl_queue = exec_q ,
285- depends = dep_evs ,
299+ return dpnp .get_result_array (
300+ _take_1d_index (a_ary , inds , axis , exec_q , res_usm_type , out )
286301 )
287- _manager .add_event_pair (h_ev , take_ev )
288-
289- if not (orig_out is None or orig_out is out ):
290- # Copy the out data from temporary buffer to original memory
291- ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
292- src = out , dst = orig_out , sycl_queue = exec_q , depends = [take_ev ]
293- )
294- _manager .add_event_pair (ht_copy_ev , cpy_ev )
295- out = orig_out
296-
297- return dpnp .get_result_array (out )
298302
299303
300304def diag_indices (n , ndim = 2 , device = None , usm_type = "device" , sycl_queue = None ):
0 commit comments