Skip to content

Commit 0b905e9

Browse files
authored
Add assume_unique and invert arguments to keras.ops.isin (#21552)
* Add arguments for isin method * Add test cases * Update argument description * Add test cases more
1 parent 7e0705a commit 0b905e9

File tree

7 files changed

+91
-12
lines changed

7 files changed

+91
-12
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,10 +769,10 @@ def isfinite(x):
769769
return jnp.isfinite(x)
770770

771771

772-
def isin(x1, x2):
772+
def isin(x1, x2, assume_unique=False, invert=False):
773773
x1 = convert_to_tensor(x1)
774774
x2 = convert_to_tensor(x2)
775-
return jnp.isin(x1, x2)
775+
return jnp.isin(x1, x2, assume_unique=assume_unique, invert=invert)
776776

777777

778778
@sparse.elementwise_unary(linear=False)

keras/src/backend/numpy/numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,10 +682,10 @@ def isfinite(x):
682682
return np.isfinite(x)
683683

684684

685-
def isin(x1, x2):
685+
def isin(x1, x2, assume_unique=False, invert=False):
686686
x1 = convert_to_tensor(x1)
687687
x2 = convert_to_tensor(x2)
688-
return np.isin(x1, x2)
688+
return np.isin(x1, x2, assume_unique=assume_unique, invert=invert)
689689

690690

691691
def isinf(x):

keras/src/backend/openvino/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ def isfinite(x):
929929
return OpenVINOKerasTensor(ov_opset.is_finite(x).output(0))
930930

931931

932-
def isin(x1, x2):
932+
def isin(x1, x2, assume_unique=False, invert=False):
933933
raise NotImplementedError("`isin` is not supported with openvino backend")
934934

935935

keras/src/backend/tensorflow/numpy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1592,7 +1592,7 @@ def isfinite(x):
15921592
return tf.math.is_finite(x)
15931593

15941594

1595-
def isin(x1, x2):
1595+
def isin(x1, x2, assume_unique=False, invert=False):
15961596
x1 = convert_to_tensor(x1)
15971597
x2 = convert_to_tensor(x2)
15981598

@@ -1605,12 +1605,18 @@ def isin(x1, x2):
16051605
x1 = tf.reshape(x1, [-1])
16061606
x2 = tf.reshape(x2, [-1])
16071607

1608+
if not assume_unique:
1609+
x2 = tf.unique(x2)[0]
1610+
16081611
if tf.size(x1) == 0 or tf.size(x2) == 0:
16091612
return tf.zeros(output_shape, dtype=tf.bool)
16101613

16111614
cmp = tf.equal(tf.expand_dims(x1, 1), tf.expand_dims(x2, 0))
16121615
result_flat = tf.reduce_any(cmp, axis=1)
16131616

1617+
if invert:
1618+
result_flat = tf.logical_not(result_flat)
1619+
16141620
return tf.reshape(result_flat, output_shape)
16151621

16161622

keras/src/backend/torch/numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,7 @@ def isfinite(x):
886886
return torch.isfinite(x)
887887

888888

889-
def isin(x1, x2):
889+
def isin(x1, x2, assume_unique=False, invert=False):
890890
x1 = convert_to_tensor(x1)
891891
x2 = convert_to_tensor(x2)
892892

@@ -900,7 +900,7 @@ def isin(x1, x2):
900900
if standardize_dtype(x2.dtype) == "bool":
901901
x2 = cast(x2, x1.dtype)
902902

903-
return torch.isin(x1, x2)
903+
return torch.isin(x1, x2, assume_unique=assume_unique, invert=invert)
904904

905905

906906
def isinf(x):

keras/src/ops/numpy.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3617,15 +3617,28 @@ def isfinite(x):
36173617

36183618

36193619
class IsIn(Operation):
3620+
def __init__(
3621+
self,
3622+
assume_unique=False,
3623+
invert=False,
3624+
*,
3625+
name=None,
3626+
):
3627+
super().__init__(name=name)
3628+
self.assume_unique = assume_unique
3629+
self.invert = invert
3630+
36203631
def call(self, x1, x2):
3621-
return backend.numpy.isin(x1, x2)
3632+
return backend.numpy.isin(
3633+
x1, x2, assume_unique=self.assume_unique, invert=self.invert
3634+
)
36223635

36233636
def compute_output_spec(self, x1, x2):
36243637
return KerasTensor(x1.shape, dtype="bool")
36253638

36263639

36273640
@keras_export(["keras.ops.isin", "keras.ops.numpy.isin"])
3628-
def isin(x1, x2):
3641+
def isin(x1, x2, assume_unique=False, invert=False):
36293642
"""Test whether each element of `x1` is present in `x2`.
36303643
36313644
This operation performs element-wise checks to determine if each value
@@ -3637,6 +3650,13 @@ def isin(x1, x2):
36373650
x1: Input tensor or array-like structure to test.
36383651
x2: Values against which each element of `x1` is tested.
36393652
Can be a tensor, list, or scalar.
3653+
assume_unique: Boolean (default: False).
3654+
If True, assumes both `x1` and `x2` contain only unique elements.
3655+
This can speed up the computation. If False, duplicates will be
3656+
handled correctly but may impact performance.
3657+
invert: A boolean (default: False).
3658+
If True, inverts the result. Entries will be `True`
3659+
where `x1` elements are not in `x2`.
36403660
36413661
Returns:
36423662
A boolean tensor of the same shape as `x1` indicating element-wise
@@ -3650,8 +3670,12 @@ def isin(x1, x2):
36503670
array([ True, False, True, False])
36513671
"""
36523672
if any_symbolic_tensors((x1, x2)):
3653-
return IsIn().symbolic_call(x1, x2)
3654-
return backend.numpy.isin(x1, x2)
3673+
return IsIn(assume_unique=assume_unique, invert=invert).symbolic_call(
3674+
x1, x2
3675+
)
3676+
return backend.numpy.isin(
3677+
x1, x2, assume_unique=assume_unique, invert=invert
3678+
)
36553679

36563680

36573681
class Isinf(Operation):

keras/src/ops/numpy_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2880,10 +2880,59 @@ def test_isin(self):
28802880
self.assertAllClose(knp.isin(x, 2), np.isin(x, 2))
28812881
self.assertAllClose(knp.isin(2, x), np.isin(2, x))
28822882

2883+
self.assertAllClose(
2884+
knp.isin(x, y, assume_unique=True),
2885+
np.isin(x, y, assume_unique=True),
2886+
)
2887+
self.assertAllClose(
2888+
knp.isin(x, 2, assume_unique=True),
2889+
np.isin(x, 2, assume_unique=True),
2890+
)
2891+
self.assertAllClose(
2892+
knp.isin(2, x, assume_unique=True),
2893+
np.isin(2, x, assume_unique=True),
2894+
)
2895+
2896+
self.assertAllClose(
2897+
knp.isin(x, y, invert=True), np.isin(x, y, invert=True)
2898+
)
2899+
self.assertAllClose(
2900+
knp.isin(x, 2, invert=True), np.isin(x, 2, invert=True)
2901+
)
2902+
self.assertAllClose(
2903+
knp.isin(2, x, invert=True), np.isin(2, x, invert=True)
2904+
)
2905+
2906+
self.assertAllClose(
2907+
knp.isin(x, y, assume_unique=True, invert=True),
2908+
np.isin(x, y, assume_unique=True, invert=True),
2909+
)
2910+
self.assertAllClose(
2911+
knp.isin(x, 2, assume_unique=True, invert=True),
2912+
np.isin(x, 2, assume_unique=True, invert=True),
2913+
)
2914+
self.assertAllClose(
2915+
knp.isin(2, x, assume_unique=True, invert=True),
2916+
np.isin(2, x, assume_unique=True, invert=True),
2917+
)
2918+
28832919
self.assertAllClose(knp.IsIn()(x, y), np.isin(x, y))
28842920
self.assertAllClose(knp.IsIn()(x, 2), np.isin(x, 2))
28852921
self.assertAllClose(knp.IsIn()(2, x), np.isin(2, x))
28862922

2923+
self.assertAllClose(
2924+
knp.IsIn(assume_unique=True)(x, y),
2925+
np.isin(x, y, assume_unique=True),
2926+
)
2927+
self.assertAllClose(
2928+
knp.IsIn(invert=True)(x, y),
2929+
np.isin(x, y, invert=True),
2930+
)
2931+
self.assertAllClose(
2932+
knp.IsIn(assume_unique=True, invert=True)(x, y),
2933+
np.isin(x, y, assume_unique=True, invert=True),
2934+
)
2935+
28872936
def test_less(self):
28882937
x = np.array([[1, 2, 3], [3, 2, 1]])
28892938
y = np.array([[4, 5, 6], [3, 2, 1]])

0 commit comments

Comments
 (0)