Skip to content

Commit 740b08b

Browse files
authored
Update dpnp.extract implementation to get rid of limitations for input arguments (#1906)
* Remove limitations from dpnp.extract implementation * Add more tests * Tune rtol and atol for a histogram test, since might fail on Windows * Fix a typo in description * Add test to cover condition as list
1 parent 03b585b commit 740b08b

File tree

9 files changed

+276
-120
lines changed

9 files changed

+276
-120
lines changed

doc/reference/sorting.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ Searching
3131
dpnp.nanargmax
3232
dpnp.argmin
3333
dpnp.nanargmin
34+
dpnp.argwhere
3435
dpnp.nonzero
3536
dpnp.flatnonzero
3637
dpnp.where
37-
dpnp.argwhere
3838
dpnp.searchsorted
3939
dpnp.extract
4040

dpnp/dpnp_iface_indexing.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -490,42 +490,86 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
490490
)
491491

492492

493-
def extract(condition, x):
493+
def extract(condition, a):
494494
"""
495495
Return the elements of an array that satisfy some condition.
496496
497+
This is equivalent to
498+
``dpnp.compress(dpnp.ravel(condition), dpnp.ravel(a))``. If `condition`
499+
is boolean :obj:`dpnp.extract` is equivalent to ``a[condition]``.
500+
501+
Note that :obj:`dpnp.place` does the exact opposite of :obj:`dpnp.extract`.
502+
497503
For full documentation refer to :obj:`numpy.extract`.
498504
505+
Parameters
506+
----------
507+
condition : {array_like, scalar}
508+
An array whose non-zero or ``True`` entries indicate the element of `a`
509+
to extract.
510+
a : {dpnp_array, usm_ndarray}
511+
Input array of the same size as `condition`.
512+
499513
Returns
500514
-------
501515
out : dpnp.ndarray
502-
Rank 1 array of values from `x` where `condition` is True.
516+
Rank 1 array of values from `a` where `condition` is ``True``.
517+
518+
See Also
519+
--------
520+
:obj:`dpnp.take` : Take elements from an array along an axis.
521+
:obj:`dpnp.put` : Replaces specified elements of an array with given values.
522+
:obj:`dpnp.copyto` : Copies values from one array to another, broadcasting
523+
as necessary.
524+
:obj:`dpnp.compress` : eturn selected slices of an array along given axis.
525+
:obj:`dpnp.place` : Change elements of an array based on conditional and
526+
input values.
527+
528+
Examples
529+
--------
530+
>>> import dpnp as np
531+
>>> a = np.arange(12).reshape((3, 4))
532+
>>> a
533+
array([[ 0, 1, 2, 3],
534+
[ 4, 5, 6, 7],
535+
[ 8, 9, 10, 11]])
536+
>>> condition = np.mod(a, 3) == 0
537+
>>> condition
538+
array([[ True, False, False, True],
539+
[False, False, True, False],
540+
[False, True, False, False]])
541+
>>> np.extract(condition, a)
542+
array([0, 3, 6, 9])
543+
544+
If `condition` is boolean:
545+
546+
>>> a[condition]
547+
array([0, 3, 6, 9])
503548
504-
Limitations
505-
-----------
506-
Parameters `condition` and `x` are supported either as
507-
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
508-
Parameter `x` must be the same shape as `condition`.
509-
Otherwise the function will be executed sequentially on CPU.
510549
"""
511550

512-
if dpnp.is_supported_array_type(condition) and dpnp.is_supported_array_type(
513-
x
514-
):
515-
if condition.shape != x.shape:
516-
pass
517-
else:
518-
dpt_condition = (
519-
condition.get_array()
520-
if isinstance(condition, dpnp_array)
521-
else condition
522-
)
523-
dpt_array = x.get_array() if isinstance(x, dpnp_array) else x
524-
return dpnp_array._create_from_usm_ndarray(
525-
dpt.extract(dpt_condition, dpt_array)
526-
)
551+
usm_a = dpnp.get_usm_ndarray(a)
552+
if not dpnp.is_supported_array_type(condition):
553+
usm_cond = dpt.asarray(
554+
condition, usm_type=a.usm_type, sycl_queue=a.sycl_queue
555+
)
556+
else:
557+
usm_cond = dpnp.get_usm_ndarray(condition)
558+
559+
if usm_cond.size != usm_a.size:
560+
usm_a = dpt.reshape(usm_a, -1)
561+
usm_cond = dpt.reshape(usm_cond, -1)
562+
563+
usm_res = dpt.take(usm_a, dpt.nonzero(usm_cond)[0])
564+
else:
565+
if usm_cond.shape != usm_a.shape:
566+
usm_a = dpt.reshape(usm_a, -1)
567+
usm_cond = dpt.reshape(usm_cond, -1)
568+
569+
usm_res = dpt.extract(usm_cond, usm_a)
527570

528-
return call_origin(numpy.extract, condition, x)
571+
dpnp.synchronize_array_data(usm_res)
572+
return dpnp_array._create_from_usm_ndarray(usm_res)
529573

530574

531575
def fill_diagonal(a, val, wrap=False):

tests/skipped_tests.tbl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
124124
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
125125
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order
126126

127-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress
128-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim
129-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
130-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
131-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
132127
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
133128
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
134129
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
174174
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
175175
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order
176176

177-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress
178-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim
179-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
180-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
181-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
182177
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
183178
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
184179
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/test_histogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_density(self, dtype):
182182
result_hist, result_edges = dpnp.histogram(iv, density=True)
183183

184184
if numpy.issubdtype(dtype, numpy.inexact):
185-
tol = numpy.finfo(dtype).resolution
185+
tol = 4 * numpy.finfo(dtype).resolution
186186
assert_allclose(result_hist, expected_hist, rtol=tol, atol=tol)
187187
assert_allclose(result_edges, expected_edges, rtol=tol, atol=tol)
188188
else:

0 commit comments

Comments
 (0)