diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 54fd4ba2..48c1578a 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, sinc +from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, setdiff1d, sinc from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -11,7 +11,6 @@ default_dtype, kron, nunique, - setdiff1d, ) from ._lib._lazy import lazy_apply diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 26d8d0cd..5010bd18 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -416,6 +416,58 @@ def pad( return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) +def setdiff1d( + x1: Array | complex, + x2: Array | complex, + /, + *, + assume_unique: bool = False, + xp: ModuleType | None = None, +) -> Array: + """ + Find the set difference of two arrays. + + Return the unique values in `x1` that are not in `x2`. + + Parameters + ---------- + x1 : array | int | float | complex | bool + Input array. + x2 : array + Input comparison array. + assume_unique : bool + If ``True``, the input arrays are both assumed to be unique, which + can speed up the calculation. Default is ``False``. + xp : array_namespace, optional + The standard-compatible namespace for `x1` and `x2`. Default: infer. + + Returns + ------- + array + 1D array of values in `x1` that are not in `x2`. The result + is sorted when `assume_unique` is ``False``, but otherwise only sorted + if the input is sorted. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + + >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1]) + >>> x2 = xp.asarray([3, 4, 5, 6]) + >>> xpx.setdiff1d(x1, x2, xp=xp) + Array([1, 2], dtype=array_api_strict.int64) + """ + + if xp is None: + xp = array_namespace(x1, x2) + + if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp): + return xp.setdiff1d(x1, x2, assume_unique=assume_unique) + + return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp) + + def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: r""" Return the normalized sinc function. diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index b25e6e3e..22ad3c25 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -813,44 +813,10 @@ def setdiff1d( /, *, assume_unique: bool = False, - xp: ModuleType | None = None, -) -> Array: - """ - Find the set difference of two arrays. - - Return the unique values in `x1` that are not in `x2`. - - Parameters - ---------- - x1 : array | int | float | complex | bool - Input array. - x2 : array - Input comparison array. - assume_unique : bool - If ``True``, the input arrays are both assumed to be unique, which - can speed up the calculation. Default is ``False``. - xp : array_namespace, optional - The standard-compatible namespace for `x1` and `x2`. Default: infer. - - Returns - ------- - array - 1D array of values in `x1` that are not in `x2`. The result - is sorted when `assume_unique` is ``False``, but otherwise only sorted - if the input is sorted. - - Examples - -------- - >>> import array_api_strict as xp - >>> import array_api_extra as xpx + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" - >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1]) - >>> x2 = xp.asarray([3, 4, 5, 6]) - >>> xpx.setdiff1d(x1, x2, xp=xp) - Array([1, 2], dtype=array_api_strict.int64) - """ - if xp is None: - xp = array_namespace(x1, x2) # https://github.com/microsoft/pyright/issues/10103 x1_, x2_ = asarrays(x1, x2, xp=xp)