Skip to content

Commit 9641064

Browse files
committed
Implement dpnp.compress
1 parent cd23361 commit 9641064

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@
4040
import operator
4141

4242
import dpctl.tensor as dpt
43+
import dpctl.tensor._tensor_impl as ti
44+
import dpctl.utils as dpu
4345
import numpy
46+
from dpctl.tensor._copy_utils import _nonzero_impl
4447
from dpctl.tensor._numpy_helper import normalize_axis_index
4548

4649
import dpnp
@@ -55,6 +58,7 @@
5558

5659
__all__ = [
5760
"choose",
61+
"compress",
5862
"diag_indices",
5963
"diag_indices_from",
6064
"diagonal",
@@ -155,6 +159,144 @@ def choose(x1, choices, out=None, mode="raise"):
155159
return call_origin(numpy.choose, x1, choices, out, mode)
156160

157161

162+
def compress(condition, a, axis=None, out=None):
163+
"""
164+
Return selected slices of an array along given axis.
165+
166+
For full documentation refer to :obj:`numpy.choose`.
167+
168+
Parameters
169+
----------
170+
condition : {array_like, dpnp.ndarray, usm_ndarray}
171+
Array that selects which entries to extract. If the length of
172+
`condition` is less than the size of `a` along `axis`, then
173+
the output is truncated to the length of `condition`.
174+
a : {dpnp.ndarray, usm_ndarray}
175+
Array to extract from.
176+
axis : {int}, optional
177+
Axis along which to extract slices. If `None`, works over the
178+
flattened array.
179+
out : {None, dpnp.ndarray, usm_ndarray}, optional
180+
If provided, the result will be placed in this array. It should
181+
be of the appropriate shape and dtype.
182+
Default: ``None``.
183+
184+
Returns
185+
-------
186+
out : dpnp.ndarray
187+
A copy of the slices of `a` where `condition` is True.
188+
189+
See also
190+
--------
191+
:obj:`dpnp.ndarray.compress` : Equivalent method.
192+
:obj:`dpnp.extract` : Equivalent function when working on 1-D arrays.
193+
"""
194+
dpnp.check_supported_arrays_type(a)
195+
if axis is None:
196+
if a.ndim != 1:
197+
a = dpnp.ravel(a)
198+
axis = 0
199+
else:
200+
axis = normalize_axis_index(operator.index(axis), a.ndim)
201+
202+
a_ary = dpnp.get_usm_ndarray(a)
203+
if not dpnp.is_supported_array_type(condition):
204+
usm_type = a_ary.usm_type
205+
q = a_ary.sycl_queue
206+
cond_ary = dpnp.as_usm_ndarray(
207+
condition,
208+
dtype=dpnp.bool,
209+
usm_type=usm_type,
210+
sycl_queue=q,
211+
)
212+
queues_ = [q]
213+
usm_types_ = [usm_type]
214+
else:
215+
cond_ary = dpnp.get_usm_ndarray(condition)
216+
queues_ = [a_ary.sycl_queue, cond_ary.sycl_queue]
217+
usm_types_ = [a_ary.usm_type, cond_ary.usm_type]
218+
if not cond_ary.ndim == 1:
219+
raise ValueError(
220+
"`condition` must be a 1-D array or un-nested " "sequence"
221+
)
222+
223+
res_usm_type = dpu.get_coerced_usm_type(usm_types_)
224+
exec_q = dpu.get_execution_queue(queues_)
225+
if exec_q is None:
226+
raise dpu.ExecutionPlacementError(
227+
"arrays must be allocated on the same SYCL queue"
228+
)
229+
230+
inds = _nonzero_impl(cond_ary) # synchronizes
231+
232+
res_dt = a_ary.dtype
233+
ind0 = inds[0]
234+
a_sh = a_ary.shape
235+
axis_end = axis + 1
236+
if 0 in a_sh[axis:axis_end] and ind0.size != 0:
237+
raise IndexError("cannot take non-empty indices from an empty axis")
238+
res_sh = a_sh[:axis] + ind0.shape + a_sh[axis_end:]
239+
240+
orig_out = out
241+
if out is not None:
242+
dpnp.check_supported_arrays_type(out)
243+
out = dpnp.get_usm_ndarray(out)
244+
245+
if not out.flags.writable:
246+
raise ValueError("provided `out` array is read-only")
247+
248+
if out.shape != res_sh:
249+
raise ValueError(
250+
"The shape of input and output arrays are inconsistent. "
251+
f"Expected output shape is {res_sh}, got {out.shape}"
252+
)
253+
254+
if res_dt != out.dtype:
255+
raise ValueError(
256+
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
257+
)
258+
259+
if dpu.get_execution_queue((a_ary.sycl_queue, out.sycl_queue)) is None:
260+
raise dpu.ExecutionPlacementError(
261+
"Input and output allocation queues are not compatible"
262+
)
263+
264+
if ti._array_overlap(a_ary, out):
265+
# Allocate a temporary buffer to avoid memory overlapping.
266+
out = dpt.empty_like(out)
267+
else:
268+
out = dpt.empty(
269+
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
270+
)
271+
272+
if out.size == 0:
273+
return out
274+
275+
_manager = dpu.SequentialOrderManager[exec_q]
276+
dep_evs = _manager.submitted_events
277+
278+
h_ev, take_ev = ti._take(
279+
src=a_ary,
280+
ind=inds,
281+
dst=out,
282+
axis_start=axis,
283+
mode=0,
284+
sycl_queue=exec_q,
285+
depends=dep_evs,
286+
)
287+
_manager.add_event_pair(h_ev, take_ev)
288+
289+
if not (orig_out is None or orig_out is out):
290+
# Copy the out data from temporary buffer to original memory
291+
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
292+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
293+
)
294+
_manager.add_event_pair(ht_copy_ev, cpy_ev)
295+
out = orig_out
296+
297+
return dpnp.get_result_array(out)
298+
299+
158300
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
159301
"""
160302
Return the indices to access the main diagonal of an array.

0 commit comments

Comments
 (0)