Skip to content

Commit a716647

Browse files
committed
WIP
1 parent feca84d commit a716647

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
14+
from ._utils._compat import array_namespace, is_jax_array, is_jax_namespace
1515
from ._utils._typing import Array
1616

1717
__all__ = [
@@ -533,6 +533,29 @@ def pad(
533533
return at(padded, tuple(slices)).set(x)
534534

535535

536+
def _ensure_unique_values(x: Array, assume_unique: bool, xp: ModuleType) -> Array:
537+
"""
538+
Wrapper around xp.unique_values
539+
540+
If x is a JAX array and we're running inside jax.jit, the output
541+
shape needs to be known. Return an array the size of x, padded
542+
with the first element of x.
543+
"""
544+
x = xp.reshape(x, (-1,))
545+
if assume_unique or x.shape == (0, ):
546+
return x
547+
548+
if is_jax_array(x):
549+
import jax
550+
551+
try:
552+
return xp.unique_values(x) # eager
553+
except jax.errors.ConcretizationError: # inside jax.jit
554+
return xp.unique_values(x, size=x.size, fill_value=x[0])
555+
556+
return xp.unique_values(x)
557+
558+
536559
def setdiff1d(
537560
x1: Array,
538561
x2: Array,
@@ -578,12 +601,17 @@ def setdiff1d(
578601
if xp is None:
579602
xp = array_namespace(x1, x2)
580603

581-
if assume_unique:
582-
x1 = xp.reshape(x1, (-1,))
583-
else:
584-
x1 = xp.unique_values(x1)
585-
x2 = xp.unique_values(x2)
586-
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
604+
x1 = _ensure_unique_values(x1, assume_unique=assume_unique, xp=xp)
605+
x2 = _ensure_unique_values(x2, assume_unique=assume_unique, xp=xp)
606+
607+
if x1.shape == (0,) or x2.shape == (0, ):
608+
return x1
609+
610+
mask = _helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)
611+
if is_jax_namespace(xp):
612+
fill = xp.zeros((), dtype=x1.dtype, device=_compat.device(x1))
613+
return xp.where(mask, fill, x1)
614+
return x1[mask]
587615

588616

589617
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from types import ModuleType
77

88
from . import _compat
9+
from ._compat import is_jax_namespace
910
from ._typing import Array
1011

1112
__all__ = ["in1d", "mean"]
@@ -46,11 +47,14 @@ def in1d(
4647
return mask
4748

4849
if not assume_unique:
49-
x1, rev_idx = xp.unique_inverse(x1)
50-
x2 = xp.unique_values(x2)
50+
if is_jax_namespace(xp):
51+
x1, rev_idx = xp.unique_inverse(x1, size=x1.size, fill_value=x1[0])
52+
x2 = xp.unique_values(x2, size=x2.size, fill_value=x2[0])
53+
else:
54+
x1, rev_idx = xp.unique_inverse(x1)
55+
x2 = xp.unique_values(x2)
5156

5257
ar = xp.concat((x1, x2))
53-
device_ = _compat.device(ar)
5458
# We need this to be a stable sort.
5559
order = xp.argsort(ar, stable=True)
5660
reverse_order = xp.argsort(order, stable=True)
@@ -61,7 +65,7 @@ def in1d(
6165
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
6266
else:
6367
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
64-
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
68+
flag = xp.concat((bool_ar, xp.asarray([invert], device=_compat.device(ar))))
6569
ret = xp.take(flag, reverse_order, axis=0)
6670

6771
if not assume_unique:

0 commit comments

Comments
 (0)