Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, sinc
from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, setdiff1d, sinc
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand All @@ -11,7 +11,6 @@
default_dtype,
kron,
nunique,
setdiff1d,
)
from ._lib._lazy import lazy_apply

Expand Down
52 changes: 52 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,58 @@ def pad(
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def setdiff1d(
x1: Array | complex,
x2: Array | complex,
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.

Return the unique values in `x1` that are not in `x2`.

Parameters
----------
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx

>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)
"""

if xp is None:
xp = array_namespace(x1, x2)

if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)

return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
r"""
Return the normalized sinc function.
Expand Down
40 changes: 3 additions & 37 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,44 +813,10 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.

Return the unique values in `x1` that are not in `x2`.

Parameters
----------
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""

>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)
"""
if xp is None:
xp = array_namespace(x1, x2)
# https://github.com/microsoft/pyright/issues/10103
x1_, x2_ = asarrays(x1, x2, xp=xp)

Expand Down
Loading