Skip to content

Commit f49888b

Browse files
committed
WIP ENH: in1d jax.jit support
1 parent 8e2e32c commit f49888b

File tree

6 files changed

+50
-19
lines changed

6 files changed

+50
-19
lines changed

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ reportUnknownVariableType = false
230230
# Redundant with mypy checks
231231
reportMissingImports = false
232232
reportMissingTypeStubs = false
233+
reportPossiblyUnboundVariable = false
233234
# false positives for input validation
234235
reportUnreachable = false
235236
# ruff handles this

src/array_api_extra/_lib/_funcs.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
14+
from ._utils._compat import array_namespace, is_jax_array, is_jax_namespace
1515
from ._utils._typing import Array
1616

1717
__all__ = [
@@ -533,6 +533,29 @@ def pad(
533533
return at(padded, tuple(slices)).set(x)
534534

535535

536+
def _ensure_unique_values(x: Array, assume_unique: bool, xp: ModuleType) -> Array:
537+
"""
538+
Wrapper around xp.unique_values
539+
540+
If x is a JAX array and we're running inside jax.jit, the output
541+
shape needs to be known. Return an array the size of x, padded
542+
with the first element of x.
543+
"""
544+
x = xp.reshape(x, (-1,))
545+
if assume_unique or x.shape == (0,):
546+
return x
547+
548+
if is_jax_array(x):
549+
import jax
550+
551+
try:
552+
return xp.unique_values(x) # eager
553+
except jax.errors.ConcretizationError: # inside jax.jit
554+
return xp.unique_values(x, size=x.size, fill_value=x[0])
555+
556+
return xp.unique_values(x)
557+
558+
536559
def setdiff1d(
537560
x1: Array,
538561
x2: Array,
@@ -578,12 +601,17 @@ def setdiff1d(
578601
if xp is None:
579602
xp = array_namespace(x1, x2)
580603

581-
if assume_unique:
582-
x1 = xp.reshape(x1, (-1,))
583-
else:
584-
x1 = xp.unique_values(x1)
585-
x2 = xp.unique_values(x2)
586-
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
604+
x1 = _ensure_unique_values(x1, assume_unique=assume_unique, xp=xp)
605+
x2 = _ensure_unique_values(x2, assume_unique=assume_unique, xp=xp)
606+
607+
if x1.shape == (0,) or x2.shape == (0,):
608+
return x1
609+
610+
mask = _helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)
611+
if is_jax_namespace(xp):
612+
fill = xp.zeros((), dtype=x1.dtype, device=_compat.device(x1))
613+
return xp.where(mask, fill, x1)
614+
return x1[mask]
587615

588616

589617
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from types import ModuleType
77

88
from . import _compat
9+
from ._compat import is_jax_namespace
910
from ._typing import Array
1011

1112
__all__ = ["in1d", "mean"]
@@ -45,13 +46,15 @@ def in1d(
4546
mask |= x1 == a
4647
return mask
4748

48-
rev_idx = xp.empty(0) # placeholder
4949
if not assume_unique:
50-
x1, rev_idx = xp.unique_inverse(x1)
51-
x2 = xp.unique_values(x2)
50+
if is_jax_namespace(xp):
51+
x1, rev_idx = xp.unique_inverse(x1, size=x1.size, fill_value=x1[0])
52+
x2 = xp.unique_values(x2, size=x2.size, fill_value=x2[0])
53+
else:
54+
x1, rev_idx = xp.unique_inverse(x1)
55+
x2 = xp.unique_values(x2)
5256

5357
ar = xp.concat((x1, x2))
54-
device_ = _compat.device(ar)
5558
# We need this to be a stable sort.
5659
order = xp.argsort(ar, stable=True)
5760
reverse_order = xp.argsort(order, stable=True)
@@ -62,12 +65,12 @@ def in1d(
6265
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
6366
else:
6467
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
65-
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
68+
flag = xp.concat((bool_ar, xp.asarray([invert], device=_compat.device(ar))))
6669
ret = xp.take(flag, reverse_order, axis=0)
6770

68-
if assume_unique:
69-
return ret[: x1.shape[0]]
70-
return xp.take(ret, rev_idx, axis=0)
71+
if not assume_unique:
72+
return xp.take(ret, rev_idx, axis=0) # type: ignore[possibly-undefined]
73+
return ret[: x1.shape[0]]
7174

7275

7376
def mean(

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
lazy_xp_function(kron, static_argnames="xp")
3535
lazy_xp_function(nunique, static_argnames="xp")
3636
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
37-
# FIXME calls in1d which calls xp.unique_values without size
37+
lazy_xp_function(setdiff1d, static_argnames=("assume_unique", "xp"))
3838
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
3939
lazy_xp_function(sinc, static_argnames="xp")
4040

tests/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
# mypy: disable-error-code=no-untyped-usage
1313

14-
# FIXME calls xp.unique_values without size
15-
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
14+
lazy_xp_function(in1d, static_argnames=("assume_unique", "invert", "xp"))
1615

1716

1817
class TestIn1D:

0 commit comments

Comments
 (0)