|
11 | 11 |
|
12 | 12 | from ._at import at |
13 | 13 | from ._utils import _compat, _helpers |
14 | | -from ._utils._compat import array_namespace, is_jax_array |
| 14 | +from ._utils._compat import array_namespace, is_jax_array, is_jax_namespace |
15 | 15 | from ._utils._typing import Array |
16 | 16 |
|
17 | 17 | __all__ = [ |
@@ -533,6 +533,29 @@ def pad( |
533 | 533 | return at(padded, tuple(slices)).set(x) |
534 | 534 |
|
535 | 535 |
|
| 536 | +def _ensure_unique_values(x: Array, assume_unique: bool, xp: ModuleType) -> Array: |
| 537 | + """ |
| 538 | + Wrapper around xp.unique_values |
| 539 | +
|
| 540 | + If x is a JAX array and we're running inside jax.jit, the output |
| 541 | + shape needs to be known. Return an array the size of x, padded |
| 542 | + with the first element of x. |
| 543 | + """ |
| 544 | + x = xp.reshape(x, (-1,)) |
| 545 | + if assume_unique or x.shape == (0, ): |
| 546 | + return x |
| 547 | + |
| 548 | + if is_jax_array(x): |
| 549 | + import jax |
| 550 | + |
| 551 | + try: |
| 552 | + return xp.unique_values(x) # eager |
| 553 | + except jax.errors.ConcretizationError: # inside jax.jit |
| 554 | + return xp.unique_values(x, size=x.size, fill_value=x[0]) |
| 555 | + |
| 556 | + return xp.unique_values(x) |
| 557 | + |
| 558 | + |
536 | 559 | def setdiff1d( |
537 | 560 | x1: Array, |
538 | 561 | x2: Array, |
@@ -578,12 +601,17 @@ def setdiff1d( |
578 | 601 | if xp is None: |
579 | 602 | xp = array_namespace(x1, x2) |
580 | 603 |
|
581 | | - if assume_unique: |
582 | | - x1 = xp.reshape(x1, (-1,)) |
583 | | - else: |
584 | | - x1 = xp.unique_values(x1) |
585 | | - x2 = xp.unique_values(x2) |
586 | | - return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] |
| 604 | + x1 = _ensure_unique_values(x1, assume_unique=assume_unique, xp=xp) |
| 605 | + x2 = _ensure_unique_values(x2, assume_unique=assume_unique, xp=xp) |
| 606 | + |
| 607 | + if x1.shape == (0,) or x2.shape == (0, ): |
| 608 | + return x1 |
| 609 | + |
| 610 | + mask = _helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp) |
| 611 | + if is_jax_namespace(xp): |
| 612 | + fill = xp.zeros((), dtype=x1.dtype, device=_compat.device(x1)) |
| 613 | + return xp.where(mask, fill, x1) |
| 614 | + return x1[mask] |
587 | 615 |
|
588 | 616 |
|
589 | 617 | def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
|
0 commit comments