Skip to content

Commit 92aefef

Browse files
authored
Implement isin function in keras.ops (#21523)
* Add isin method for numpy * Add isin method for openvino * Correct logic for tensorflow * Add isin method for ops * Add test cases and update description * update excluded_concrete_tests.txt * correct isin method for tensorflow
1 parent 3744538 commit 92aefef

File tree

12 files changed

+146
-0
lines changed

12 files changed

+146
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@
199199
from keras.src.ops.numpy import inner as inner
200200
from keras.src.ops.numpy import isclose as isclose
201201
from keras.src.ops.numpy import isfinite as isfinite
202+
from keras.src.ops.numpy import isin as isin
202203
from keras.src.ops.numpy import isinf as isinf
203204
from keras.src.ops.numpy import isnan as isnan
204205
from keras.src.ops.numpy import kaiser as kaiser

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
from keras.src.ops.numpy import inner as inner
8989
from keras.src.ops.numpy import isclose as isclose
9090
from keras.src.ops.numpy import isfinite as isfinite
91+
from keras.src.ops.numpy import isin as isin
9192
from keras.src.ops.numpy import isinf as isinf
9293
from keras.src.ops.numpy import isnan as isnan
9394
from keras.src.ops.numpy import kaiser as kaiser

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@
199199
from keras.src.ops.numpy import inner as inner
200200
from keras.src.ops.numpy import isclose as isclose
201201
from keras.src.ops.numpy import isfinite as isfinite
202+
from keras.src.ops.numpy import isin as isin
202203
from keras.src.ops.numpy import isinf as isinf
203204
from keras.src.ops.numpy import isnan as isnan
204205
from keras.src.ops.numpy import kaiser as kaiser

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
from keras.src.ops.numpy import inner as inner
8989
from keras.src.ops.numpy import isclose as isclose
9090
from keras.src.ops.numpy import isfinite as isfinite
91+
from keras.src.ops.numpy import isin as isin
9192
from keras.src.ops.numpy import isinf as isinf
9293
from keras.src.ops.numpy import isnan as isnan
9394
from keras.src.ops.numpy import kaiser as kaiser

keras/src/backend/jax/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,12 @@ def isfinite(x):
769769
return jnp.isfinite(x)
770770

771771

772+
def isin(x1, x2):
773+
x1 = convert_to_tensor(x1)
774+
x2 = convert_to_tensor(x2)
775+
return jnp.isin(x1, x2)
776+
777+
772778
@sparse.elementwise_unary(linear=False)
773779
def isinf(x):
774780
x = convert_to_tensor(x)

keras/src/backend/numpy/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,12 @@ def isfinite(x):
682682
return np.isfinite(x)
683683

684684

685+
def isin(x1, x2):
686+
x1 = convert_to_tensor(x1)
687+
x2 = convert_to_tensor(x2)
688+
return np.isin(x1, x2)
689+
690+
685691
def isinf(x):
686692
return np.isinf(x)
687693

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ NumpyDtypeTest::test_flip
3232
NumpyDtypeTest::test_floor
3333
NumpyDtypeTest::test_inner
3434
NumpyDtypeTest::test_isfinite
35+
NumpyDtypeTest::test_isin
3536
NumpyDtypeTest::test_isinf
3637
NumpyDtypeTest::test_isnan
3738
NumpyDtypeTest::test_linspace
@@ -142,6 +143,7 @@ NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan
142143
NumpyTwoInputOpsCorrectnessTest::test_einsum
143144
NumpyTwoInputOpsCorrectnessTest::test_heaviside
144145
NumpyTwoInputOpsCorrectnessTest::test_inner
146+
NumpyTwoInputOpsCorrectnessTest::test_isin
145147
NumpyTwoInputOpsCorrectnessTest::test_linspace
146148
NumpyTwoInputOpsCorrectnessTest::test_logspace
147149
NumpyTwoInputOpsCorrectnessTest::test_quantile
@@ -160,7 +162,9 @@ NumpyOneInputOpsStaticShapeTest::test_angle
160162
NumpyOneInputOpsStaticShapeTest::test_cbrt
161163
NumpyOneInputOpsStaticShapeTest::test_deg2rad
162164
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
165+
NumpyTwoInputOpsDynamicShapeTest::test_isin
163166
NumpyTwoInputOpsStaticShapeTest::test_heaviside
167+
NumpyTwoInputOpsStaticShapeTest::test_isin
164168
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
165169
CoreOpsBehaviorTests::test_scan_invalid_arguments
166170
CoreOpsCallsTests::test_associative_scan_basic_call

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,10 @@ def isfinite(x):
929929
return OpenVINOKerasTensor(ov_opset.is_finite(x).output(0))
930930

931931

932+
def isin(x1, x2):
933+
raise NotImplementedError("`isin` is not supported with openvino backend")
934+
935+
932936
def isinf(x):
933937
x = get_ov_output(x)
934938
return OpenVINOKerasTensor(ov_opset.is_inf(x).output(0))

keras/src/backend/tensorflow/numpy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,28 @@ def isfinite(x):
15921592
return tf.math.is_finite(x)
15931593

15941594

1595+
def isin(x1, x2):
1596+
x1 = convert_to_tensor(x1)
1597+
x2 = convert_to_tensor(x2)
1598+
1599+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
1600+
x1 = tf.cast(x1, dtype)
1601+
x2 = tf.cast(x2, dtype)
1602+
1603+
output_shape = tf.shape(x1)
1604+
1605+
x1 = tf.reshape(x1, [-1])
1606+
x2 = tf.reshape(x2, [-1])
1607+
1608+
if tf.size(x1) == 0 or tf.size(x2) == 0:
1609+
return tf.zeros(output_shape, dtype=tf.bool)
1610+
1611+
cmp = tf.equal(tf.expand_dims(x1, 1), tf.expand_dims(x2, 0))
1612+
result_flat = tf.reduce_any(cmp, axis=1)
1613+
1614+
return tf.reshape(result_flat, output_shape)
1615+
1616+
15951617
def isinf(x):
15961618
x = convert_to_tensor(x)
15971619
dtype_as_dtype = tf.as_dtype(x.dtype)

keras/src/backend/torch/numpy.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,23 @@ def isfinite(x):
886886
return torch.isfinite(x)
887887

888888

889+
def isin(x1, x2):
890+
x1 = convert_to_tensor(x1)
891+
x2 = convert_to_tensor(x2)
892+
893+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
894+
if dtype == "bool":
895+
x1 = cast(x1, "int32")
896+
x2 = cast(x2, "int32")
897+
898+
if standardize_dtype(x1.dtype) == "bool":
899+
x1 = cast(x1, x2.dtype)
900+
if standardize_dtype(x2.dtype) == "bool":
901+
x2 = cast(x2, x1.dtype)
902+
903+
return torch.isin(x1, x2)
904+
905+
889906
def isinf(x):
890907
x = convert_to_tensor(x)
891908
return torch.isinf(x)

0 commit comments

Comments
 (0)