Skip to content

Commit 6efe3b9

Browse files
committed
ENH: setdiff1d delegate function
1 parent 747f994 commit 6efe3b9

File tree

3 files changed

+56
-39
lines changed

3 files changed

+56
-39
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, sinc
3+
from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, setdiff1d, sinc
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -11,7 +11,6 @@
1111
default_dtype,
1212
kron,
1313
nunique,
14-
setdiff1d,
1514
)
1615
from ._lib._lazy import lazy_apply
1716

src/array_api_extra/_delegation.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,58 @@ def pad(
416416
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
417417

418418

419+
def setdiff1d(
420+
x1: Array | complex,
421+
x2: Array | complex,
422+
/,
423+
*,
424+
assume_unique: bool = False,
425+
xp: ModuleType | None = None,
426+
) -> Array:
427+
"""
428+
Find the set difference of two arrays.
429+
430+
Return the unique values in `x1` that are not in `x2`.
431+
432+
Parameters
433+
----------
434+
x1 : array | int | float | complex | bool
435+
Input array.
436+
x2 : array
437+
Input comparison array.
438+
assume_unique : bool
439+
If ``True``, the input arrays are both assumed to be unique, which
440+
can speed up the calculation. Default is ``False``.
441+
xp : array_namespace, optional
442+
The standard-compatible namespace for `x1` and `x2`. Default: infer.
443+
444+
Returns
445+
-------
446+
array
447+
1D array of values in `x1` that are not in `x2`. The result
448+
is sorted when `assume_unique` is ``False``, but otherwise only sorted
449+
if the input is sorted.
450+
451+
Examples
452+
--------
453+
>>> import array_api_strict as xp
454+
>>> import array_api_extra as xpx
455+
456+
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
457+
>>> x2 = xp.asarray([3, 4, 5, 6])
458+
>>> xpx.setdiff1d(x1, x2, xp=xp)
459+
Array([1, 2], dtype=array_api_strict.int64)
460+
"""
461+
462+
if xp is None:
463+
xp = array_namespace(x1, x2)
464+
465+
if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
466+
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)
467+
468+
return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)
469+
470+
419471
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
420472
r"""
421473
Return the normalized sinc function.

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -813,44 +813,10 @@ def setdiff1d(
813813
/,
814814
*,
815815
assume_unique: bool = False,
816-
xp: ModuleType | None = None,
817-
) -> Array:
818-
"""
819-
Find the set difference of two arrays.
820-
821-
Return the unique values in `x1` that are not in `x2`.
822-
823-
Parameters
824-
----------
825-
x1 : array | int | float | complex | bool
826-
Input array.
827-
x2 : array
828-
Input comparison array.
829-
assume_unique : bool
830-
If ``True``, the input arrays are both assumed to be unique, which
831-
can speed up the calculation. Default is ``False``.
832-
xp : array_namespace, optional
833-
The standard-compatible namespace for `x1` and `x2`. Default: infer.
834-
835-
Returns
836-
-------
837-
array
838-
1D array of values in `x1` that are not in `x2`. The result
839-
is sorted when `assume_unique` is ``False``, but otherwise only sorted
840-
if the input is sorted.
841-
842-
Examples
843-
--------
844-
>>> import array_api_strict as xp
845-
>>> import array_api_extra as xpx
816+
xp: ModuleType,
817+
) -> Array: # numpydoc ignore=PR01,RT01
818+
"""See docstring in `array_api_extra._delegation.py`."""
846819

847-
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
848-
>>> x2 = xp.asarray([3, 4, 5, 6])
849-
>>> xpx.setdiff1d(x1, x2, xp=xp)
850-
Array([1, 2], dtype=array_api_strict.int64)
851-
"""
852-
if xp is None:
853-
xp = array_namespace(x1, x2)
854820
# https://github.com/microsoft/pyright/issues/10103
855821
x1_, x2_ = asarrays(x1, x2, xp=xp)
856822

0 commit comments

Comments
 (0)