Skip to content

Commit a488d96

Browse files
committed
Implement dpnp.compress
1 parent b0dc412 commit a488d96

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@
4040
import operator
4141

4242
import dpctl.tensor as dpt
43+
import dpctl.tensor._tensor_impl as ti
44+
import dpctl.utils as dpu
4345
import numpy
46+
from dpctl.tensor._copy_utils import _nonzero_impl
4447
from dpctl.tensor._numpy_helper import normalize_axis_index
4548

4649
import dpnp
@@ -55,6 +58,7 @@
5558

5659
__all__ = [
5760
"choose",
61+
"compress",
5862
"diag_indices",
5963
"diag_indices_from",
6064
"diagonal",
@@ -154,6 +158,144 @@ def choose(x1, choices, out=None, mode="raise"):
154158
return call_origin(numpy.choose, x1, choices, out, mode)
155159

156160

161+
def compress(condition, a, axis=None, out=None):
    """
    Return selected slices of an array along given axis.

    For full documentation refer to :obj:`numpy.compress`.

    Parameters
    ----------
    condition : {array_like, dpnp.ndarray, usm_ndarray}
        Array that selects which entries to extract. If the length of
        `condition` is less than the size of `a` along `axis`, then
        the output is truncated to the length of `condition`.
    a : {dpnp.ndarray, usm_ndarray}
        Array to extract from.
    axis : {None, int}, optional
        Axis along which to extract slices. If `None`, works over the
        flattened array.
        Default: ``None``.
    out : {None, dpnp.ndarray, usm_ndarray}, optional
        If provided, the result will be placed in this array. It should
        be of the appropriate shape and dtype.
        Default: ``None``.

    Returns
    -------
    out : dpnp.ndarray
        A copy of the slices of `a` where `condition` is True.

    Raises
    ------
    ValueError
        If `condition` is not one-dimensional, or if `out` is read-only,
        has a shape different from the expected result shape, or has
        a dtype different from the dtype of `a`.
    IndexError
        If `condition` selects any entry while `a` has zero length
        along `axis`.
    ExecutionPlacementError
        If the arrays involved are allocated on incompatible SYCL queues.

    See also
    --------
    :obj:`dpnp.ndarray.compress` : Equivalent method.
    :obj:`dpnp.extract` : Equivalent function when working on 1-D arrays.
    """

    dpnp.check_supported_arrays_type(a)
    if axis is None:
        # No axis given: operate on the flattened input.
        if a.ndim != 1:
            a = dpnp.ravel(a)
        axis = 0
    else:
        axis = normalize_axis_index(operator.index(axis), a.ndim)

    a_ary = dpnp.get_usm_ndarray(a)
    if not dpnp.is_supported_array_type(condition):
        # Host-side `condition`: allocate it next to `a`.
        usm_type = a_ary.usm_type
        q = a_ary.sycl_queue
        cond_ary = dpnp.as_usm_ndarray(
            condition,
            dtype=dpnp.bool,
            usm_type=usm_type,
            sycl_queue=q,
        )
        queues_ = [q]
        usm_types_ = [usm_type]
    else:
        cond_ary = dpnp.get_usm_ndarray(condition)
        queues_ = [a_ary.sycl_queue, cond_ary.sycl_queue]
        usm_types_ = [a_ary.usm_type, cond_ary.usm_type]

    if cond_ary.ndim != 1:
        raise ValueError(
            "`condition` must be a 1-D array or un-nested sequence"
        )

    res_usm_type = dpu.get_coerced_usm_type(usm_types_)
    exec_q = dpu.get_execution_queue(queues_)
    if exec_q is None:
        raise dpu.ExecutionPlacementError(
            "arrays must be allocated on the same SYCL queue"
        )

    inds = _nonzero_impl(cond_ary)  # synchronizes

    res_dt = a_ary.dtype
    ind0 = inds[0]
    a_sh = a_ary.shape
    axis_end = axis + 1
    if 0 in a_sh[axis:axis_end] and ind0.size != 0:
        raise IndexError("cannot take non-empty indices from an empty axis")
    res_sh = a_sh[:axis] + ind0.shape + a_sh[axis_end:]

    # `out` stays the object the caller passed in, `out_ary` is its
    # underlying usm_ndarray, and `dst` is the buffer `_take` writes into
    # (either `out_ary` itself or a temporary when `out` overlaps `a`).
    out_ary = None
    if out is not None:
        dpnp.check_supported_arrays_type(out)
        out_ary = dpnp.get_usm_ndarray(out)

        if not out_ary.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if out_ary.shape != res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res_sh}, got {out_ary.shape}"
            )

        if res_dt != out_ary.dtype:
            raise ValueError(
                f"Output array of type {res_dt} is needed, "
                f"got {out_ary.dtype}"
            )

        if (
            dpu.get_execution_queue((a_ary.sycl_queue, out_ary.sycl_queue))
            is None
        ):
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        if ti._array_overlap(a_ary, out_ary):
            # Allocate a temporary buffer to avoid memory overlapping.
            dst = dpt.empty_like(out_ary)
        else:
            dst = out_ary
    else:
        dst = dpt.empty(
            res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
        )

    if dst.size == 0:
        # Nothing to gather; still hand back the caller's `out` (or a
        # dpnp array wrapping the fresh allocation), never a raw buffer.
        return dpnp.get_result_array(dst if out is None else out)

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # NOTE(review): `mode=0` is forwarded to `_take` unchanged; confirm how
    # that mode treats indices past the end of `axis`, which can occur when
    # `condition` is longer than `a` along `axis`.
    h_ev, take_ev = ti._take(
        src=a_ary,
        ind=inds,
        dst=dst,
        axis_start=axis,
        mode=0,
        sycl_queue=exec_q,
        depends=dep_evs,
    )
    _manager.add_event_pair(h_ev, take_ev)

    if out_ary is not None and dst is not out_ary:
        # Copy the data from the temporary buffer to the original memory.
        ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=dst, dst=out_ary, sycl_queue=exec_q, depends=[take_ev]
        )
        _manager.add_event_pair(ht_copy_ev, cpy_ev)

    return dpnp.get_result_array(dst if out is None else out)
297+
298+
157299
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
158300
"""
159301
Return the indices to access the main diagonal of an array.

0 commit comments

Comments
 (0)