3737
3838"""
3939
40+ # pylint: disable=protected-access
41+
4042import operator
4143
44+ import dpctl
4245import dpctl .tensor as dpt
4346import dpctl .tensor ._tensor_impl as ti
4447import dpctl .utils as dpu
@@ -158,6 +161,78 @@ def choose(x1, choices, out=None, mode="raise"):
158161 return call_origin (numpy .choose , x1 , choices , out , mode )
159162
160163
def _take_1d_index(
    x: dpt.usm_ndarray,
    inds: tuple[dpt.usm_ndarray],
    axis: int,
    q: dpctl.SyclQueue,
    usm_type: str,
    out: dpt.usm_ndarray | None = None,
) -> dpt.usm_ndarray:
    """
    Take elements of `x` along `axis` at the positions given by ``inds[0]``.

    Argument validation of `x`, `inds`, `axis`, `q` and `usm_type` is assumed
    to have been done by the caller; only `out` is validated here.

    Parameters
    ----------
    x : usm_ndarray
        Source array.
    inds : tuple of usm_ndarray
        Index arrays handed to the take kernel; only ``inds[0]`` is used to
        compute the result shape.
    axis : int
        Axis of `x` that is replaced by the shape of ``inds[0]``.
    q : dpctl.SyclQueue
        Queue the kernels are submitted to.
    usm_type : str
        USM allocation type used when a fresh result array is allocated.
    out : {None, dpnp.ndarray, usm_ndarray}, optional
        Destination array. It must be writable, have the expected result
        shape, have the same dtype as `x` and be allocated on a queue
        compatible with `q`.

    Returns
    -------
    usm_ndarray or the caller's `out`
        A newly allocated array when `out` is ``None``; otherwise the very
        object that was passed as `out`.

    Raises
    ------
    IndexError
        If non-empty indices are taken from an empty axis.
    ValueError
        If `out` is read-only, has the wrong shape or the wrong dtype.
    dpctl.utils.ExecutionPlacementError
        If `out` is allocated on a queue incompatible with `q`.

    """

    # arg validation assumed done by caller
    x_sh = x.shape
    ind0 = inds[0]
    axis_end = axis + 1
    if 0 in x_sh[axis:axis_end] and ind0.size != 0:
        raise IndexError("cannot take non-empty indices from an empty axis")
    res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]

    # `out_ary` is the usm_ndarray behind the caller's `out` (if any) and
    # `dst` is the array the take kernel actually writes into. Keeping them
    # apart from `out` matters: `out` may be a wrapper object, so an identity
    # check against `out` cannot tell whether a temporary buffer was used,
    # and the wrapper must not be passed to the low-level copy routine.
    out_ary = None
    if out is not None:
        dpnp.check_supported_arrays_type(out)
        out_ary = dpnp.get_usm_ndarray(out)

        if not out_ary.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if out_ary.shape != res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res_sh}, got {out_ary.shape}"
            )

        if x.dtype != out_ary.dtype:
            raise ValueError(
                f"Output array of type {x.dtype} is needed, got {out_ary.dtype}"
            )

        if dpu.get_execution_queue((q, out_ary.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        if ti._array_overlap(x, out_ary):
            # Allocate a temporary buffer to avoid memory overlapping.
            dst = dpt.empty_like(out_ary)
        else:
            dst = out_ary
    else:
        dst = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)

    _manager = dpu.SequentialOrderManager[q]
    dep_evs = _manager.submitted_events

    # always use wrap mode here
    h_ev, take_ev = ti._take(
        src=x,
        ind=inds,
        dst=dst,
        axis_start=axis,
        mode=0,
        sycl_queue=q,
        depends=dep_evs,
    )
    _manager.add_event_pair(h_ev, take_ev)

    if out_ary is None:
        return dst

    if dst is not out_ary:
        # Copy the out data from temporary buffer to original memory
        ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=dst, dst=out_ary, sycl_queue=q, depends=[take_ev]
        )
        _manager.add_event_pair(ht_copy_ev, cpy_ev)

    # Hand back the caller's own object, as before.
    return out
234+
235+
161236def compress (condition , a , axis = None , out = None ):
162237 """
163238 Return selected slices of an array along given axis.
@@ -195,8 +270,7 @@ def compress(condition, a, axis=None, out=None):
195270 if a .ndim != 1 :
196271 a = dpnp .ravel (a )
197272 axis = 0
198- else :
199- axis = normalize_axis_index (operator .index (axis ), a .ndim )
273+ axis = normalize_axis_index (operator .index (axis ), a .ndim )
200274
201275 a_ary = dpnp .get_usm_ndarray (a )
202276 if not dpnp .is_supported_array_type (condition ):
@@ -216,7 +290,7 @@ def compress(condition, a, axis=None, out=None):
216290 usm_types_ = [a_ary .usm_type , cond_ary .usm_type ]
217291 if not cond_ary .ndim == 1 :
218292 raise ValueError (
219- "`condition` must be a 1-D array or un-nested " " sequence"
293+ "`condition` must be a 1-D array or un-nested sequence"
220294 )
221295
222296 res_usm_type = dpu .get_coerced_usm_type (usm_types_ )
@@ -226,74 +300,12 @@ def compress(condition, a, axis=None, out=None):
226300 "arrays must be allocated on the same SYCL queue"
227301 )
228302
229- inds = _nonzero_impl (cond_ary ) # synchronizes
230-
231- res_dt = a_ary .dtype
232- ind0 = inds [0 ]
233- a_sh = a_ary .shape
234- axis_end = axis + 1
235- if 0 in a_sh [axis :axis_end ] and ind0 .size != 0 :
236- raise IndexError ("cannot take non-empty indices from an empty axis" )
237- res_sh = a_sh [:axis ] + ind0 .shape + a_sh [axis_end :]
238-
239- orig_out = out
240- if out is not None :
241- dpnp .check_supported_arrays_type (out )
242- out = dpnp .get_usm_ndarray (out )
243-
244- if not out .flags .writable :
245- raise ValueError ("provided `out` array is read-only" )
246-
247- if out .shape != res_sh :
248- raise ValueError (
249- "The shape of input and output arrays are inconsistent. "
250- f"Expected output shape is { res_sh } , got { out .shape } "
251- )
252-
253- if res_dt != out .dtype :
254- raise ValueError (
255- f"Output array of type { res_dt } is needed, " f"got { out .dtype } "
256- )
257-
258- if dpu .get_execution_queue ((a_ary .sycl_queue , out .sycl_queue )) is None :
259- raise dpu .ExecutionPlacementError (
260- "Input and output allocation queues are not compatible"
261- )
262-
263- if ti ._array_overlap (a_ary , out ):
264- # Allocate a temporary buffer to avoid memory overlapping.
265- out = dpt .empty_like (out )
266- else :
267- out = dpt .empty (
268- res_sh , dtype = res_dt , usm_type = res_usm_type , sycl_queue = exec_q
269- )
270-
271- if out .size == 0 :
272- return out
303+ # _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
304+ inds = _nonzero_impl (cond_ary )
273305
274- _manager = dpu .SequentialOrderManager [exec_q ]
275- dep_evs = _manager .submitted_events
276-
277- h_ev , take_ev = ti ._take (
278- src = a_ary ,
279- ind = inds ,
280- dst = out ,
281- axis_start = axis ,
282- mode = 0 ,
283- sycl_queue = exec_q ,
284- depends = dep_evs ,
306+ return dpnp .get_result_array (
307+ _take_1d_index (a_ary , inds , axis , exec_q , res_usm_type , out )
285308 )
286- _manager .add_event_pair (h_ev , take_ev )
287-
288- if not (orig_out is None or orig_out is out ):
289- # Copy the out data from temporary buffer to original memory
290- ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
291- src = out , dst = orig_out , sycl_queue = exec_q , depends = [take_ev ]
292- )
293- _manager .add_event_pair (ht_copy_ev , cpy_ev )
294- out = orig_out
295-
296- return dpnp .get_result_array (out )
297309
298310
299311def diag_indices (n , ndim = 2 , device = None , usm_type = "device" , sycl_queue = None ):
0 commit comments