@@ -40,7 +40,7 @@ def _get_indexing_mode(name):
4040 )
4141
4242
43- def take (x , indices , / , * , axis = None , mode = "wrap" ):
43+ def take (x , indices , / , * , axis = None , out = None , mode = "wrap" ):
4444 """take(x, indices, axis=None, mode="wrap")
4545
4646 Takes elements from an array along a given axis at given indices.
@@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
5454 The axis along which the values will be selected.
5555 If ``x`` is one-dimensional, this argument is optional.
5656 Default: ``None``.
57+ out (Optional[usm_ndarray]):
58+ Output array to populate. Array must have the correct
59+ shape and the expected data type.
5760 mode (str, optional):
5861 How out-of-bounds indices will be handled. Possible values
5962 are:
@@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
121124 raise ValueError ("`axis` must be 0 for an array of dimension 0." )
122125 res_shape = indices .shape
123126
124- res = dpt .empty (
125- res_shape , dtype = x .dtype , usm_type = res_usm_type , sycl_queue = exec_q
126- )
127+ dt = x .dtype
128+
129+ orig_out = out
130+ if out is not None :
131+ if not isinstance (out , dpt .usm_ndarray ):
132+ raise TypeError (
133+ f"output array must be of usm_ndarray type, got { type (out )} "
134+ )
135+ if not out .flags .writable :
136+ raise ValueError ("provided `out` array is read-only" )
137+
138+ if out .shape != res_shape :
139+ raise ValueError (
140+ "The shape of input and output arrays are inconsistent. "
141+ f"Expected output shape is { res_shape } , got { out .shape } "
142+ )
143+ if dt != out .dtype :
144+ raise ValueError (
145+ f"Output array of type { dt } is needed, " f"got { out .dtype } "
146+ )
147+ if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
148+ raise dpctl .utils .ExecutionPlacementError (
149+ "Input and output allocation queues are not compatible"
150+ )
151+ if ti ._array_overlap (x , out ):
152+ out = dpt .empty_like (out )
153+ else :
154+ out = dpt .empty (
155+ res_shape , dtype = dt , usm_type = res_usm_type , sycl_queue = exec_q
156+ )
127157
128158 _manager = dpctl .utils .SequentialOrderManager [exec_q ]
129159 deps_ev = _manager .submitted_events
130160 hev , take_ev = ti ._take (
131- x , (indices ,), res , axis , mode , sycl_queue = exec_q , depends = deps_ev
161+ x , (indices ,), out , axis , mode , sycl_queue = exec_q , depends = deps_ev
132162 )
133163 _manager .add_event_pair (hev , take_ev )
134164
135- return res
165+ if not (orig_out is None or out is orig_out ):
166+ # Copy the out data from temporary buffer to original memory
167+ ht_e_cpy , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
168+ src = out , dst = orig_out , sycl_queue = exec_q , depends = [take_ev ]
169+ )
170+ _manager .add_event_pair (ht_e_cpy , cpy_ev )
171+ out = orig_out
172+
173+ return out
136174
137175
138176def put (x , indices , vals , / , * , axis = None , mode = "wrap" ):
0 commit comments