3737
3838"""
3939
40+ # pylint: disable=protected-access
41+
4042import operator
4143
4244import dpctl .tensor as dpt
@@ -158,6 +160,71 @@ def choose(x1, choices, out=None, mode="raise"):
158160 return call_origin (numpy .choose , x1 , choices , out , mode )
159161
160162
def _take_1d_index(x, inds, axis, q, usm_type, out=None):
    """
    Gather slices of `x` along `axis` at the positions listed in `inds[0]`.

    Validation of `x`, `inds` and `axis` is assumed to have been done by the
    caller; only `out` is validated here.

    Parameters
    ----------
    x : usm_ndarray
        Source array (the caller in this file passes the result of
        ``dpnp.get_usm_ndarray``).
    inds : sequence of index arrays
        Only ``inds[0]`` is inspected for its shape and size; the whole
        sequence is forwarded to ``ti._take``.
    axis : int
        Axis of `x` that is replaced by the shape of ``inds[0]``.
        Presumably already normalized to ``[0, x.ndim)`` -- confirm with
        the caller.
    q : SyclQueue
        Execution queue used for allocation and kernel submission.
    usm_type : str
        USM allocation type used when a fresh result array is created.
    out : {None, dpnp.ndarray, usm_ndarray}, optional
        Destination array. Must be writable, have the result shape and the
        dtype of `x`, and live on a queue compatible with `q`.

    Returns
    -------
    The caller's `out` object when one was given, otherwise a newly
    allocated ``usm_ndarray`` holding the result.

    Raises
    ------
    IndexError
        If indices are non-empty but the indexed axis of `x` has length 0.
    ValueError
        If `out` is read-only, or has the wrong shape or dtype.
    ExecutionPlacementError
        If `q` and the queue of `out` are not compatible.

    """

    # arg validation assumed done by caller
    x_sh = x.shape
    ind0 = inds[0]
    axis_end = axis + 1
    if 0 in x_sh[axis:axis_end] and ind0.size != 0:
        raise IndexError("cannot take non-empty indices from an empty axis")
    res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]

    # `out` stays bound to the caller's object (it is what gets returned);
    # `out_usm` is its usm_ndarray view and `res` is the array the kernel
    # actually writes into. Keeping the three apart avoids comparing the
    # caller's wrapper against its unwrapped usm_ndarray by identity.
    out_usm = None
    if out is None:
        res = dpt.empty(
            res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q
        )
    else:
        dpnp.check_supported_arrays_type(out)
        out_usm = dpnp.get_usm_ndarray(out)

        if not out_usm.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if out_usm.shape != res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res_sh}, got {out_usm.shape}"
            )

        if x.dtype != out_usm.dtype:
            raise ValueError(
                f"Output array of type {x.dtype} is needed, "
                f"got {out_usm.dtype}"
            )

        if dpu.get_execution_queue((q, out_usm.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        if ti._array_overlap(x, out_usm):
            # Allocate a temporary buffer to avoid memory overlapping.
            res = dpt.empty_like(out_usm)
        else:
            res = out_usm

    _manager = dpu.SequentialOrderManager[q]
    dep_evs = _manager.submitted_events

    # always use wrap mode here
    h_ev, take_ev = ti._take(
        src=x,
        ind=inds,
        dst=res,
        axis_start=axis,
        mode=0,
        sycl_queue=q,
        depends=dep_evs,
    )
    _manager.add_event_pair(h_ev, take_ev)

    if out is None:
        return res

    if res is not out_usm:
        # A temporary buffer was used: copy its data into the memory of the
        # caller's `out` once the take kernel has finished.
        ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=res, dst=out_usm, sycl_queue=q, depends=[take_ev]
        )
        _manager.add_event_pair(ht_copy_ev, cpy_ev)

    return out
226+
227+
161228def compress (condition , a , axis = None , out = None ):
162229 """
163230 Return selected slices of an array along given axis.
@@ -195,8 +262,7 @@ def compress(condition, a, axis=None, out=None):
195262 if a .ndim != 1 :
196263 a = dpnp .ravel (a )
197264 axis = 0
198- else :
199- axis = normalize_axis_index (operator .index (axis ), a .ndim )
265+ axis = normalize_axis_index (operator .index (axis ), a .ndim )
200266
201267 a_ary = dpnp .get_usm_ndarray (a )
202268 if not dpnp .is_supported_array_type (condition ):
@@ -216,7 +282,7 @@ def compress(condition, a, axis=None, out=None):
216282 usm_types_ = [a_ary .usm_type , cond_ary .usm_type ]
217283 if not cond_ary .ndim == 1 :
218284 raise ValueError (
219- "`condition` must be a 1-D array or un-nested " " sequence"
285+ "`condition` must be a 1-D array or un-nested sequence"
220286 )
221287
222288 res_usm_type = dpu .get_coerced_usm_type (usm_types_ )
@@ -226,74 +292,12 @@ def compress(condition, a, axis=None, out=None):
226292 "arrays must be allocated on the same SYCL queue"
227293 )
228294
229- inds = _nonzero_impl (cond_ary ) # synchronizes
230-
231- res_dt = a_ary .dtype
232- ind0 = inds [0 ]
233- a_sh = a_ary .shape
234- axis_end = axis + 1
235- if 0 in a_sh [axis :axis_end ] and ind0 .size != 0 :
236- raise IndexError ("cannot take non-empty indices from an empty axis" )
237- res_sh = a_sh [:axis ] + ind0 .shape + a_sh [axis_end :]
238-
239- orig_out = out
240- if out is not None :
241- dpnp .check_supported_arrays_type (out )
242- out = dpnp .get_usm_ndarray (out )
243-
244- if not out .flags .writable :
245- raise ValueError ("provided `out` array is read-only" )
246-
247- if out .shape != res_sh :
248- raise ValueError (
249- "The shape of input and output arrays are inconsistent. "
250- f"Expected output shape is { res_sh } , got { out .shape } "
251- )
252-
253- if res_dt != out .dtype :
254- raise ValueError (
255- f"Output array of type { res_dt } is needed, " f"got { out .dtype } "
256- )
257-
258- if dpu .get_execution_queue ((a_ary .sycl_queue , out .sycl_queue )) is None :
259- raise dpu .ExecutionPlacementError (
260- "Input and output allocation queues are not compatible"
261- )
262-
263- if ti ._array_overlap (a_ary , out ):
264- # Allocate a temporary buffer to avoid memory overlapping.
265- out = dpt .empty_like (out )
266- else :
267- out = dpt .empty (
268- res_sh , dtype = res_dt , usm_type = res_usm_type , sycl_queue = exec_q
269- )
270-
271- if out .size == 0 :
272- return out
295+ # _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
296+ inds = _nonzero_impl (cond_ary )
273297
274- _manager = dpu .SequentialOrderManager [exec_q ]
275- dep_evs = _manager .submitted_events
276-
277- h_ev , take_ev = ti ._take (
278- src = a_ary ,
279- ind = inds ,
280- dst = out ,
281- axis_start = axis ,
282- mode = 0 ,
283- sycl_queue = exec_q ,
284- depends = dep_evs ,
298+ return dpnp .get_result_array (
299+ _take_1d_index (a_ary , inds , axis , exec_q , res_usm_type , out )
285300 )
286- _manager .add_event_pair (h_ev , take_ev )
287-
288- if not (orig_out is None or orig_out is out ):
289- # Copy the out data from temporary buffer to original memory
290- ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
291- src = out , dst = orig_out , sycl_queue = exec_q , depends = [take_ev ]
292- )
293- _manager .add_event_pair (ht_copy_ev , cpy_ev )
294- out = orig_out
295-
296- return dpnp .get_result_array (out )
297301
298302
299303def diag_indices (n , ndim = 2 , device = None , usm_type = "device" , sycl_queue = None ):
0 commit comments