Skip to content

Commit 124a258

Browse files
authored
Implement gcd function in keras.ops (#21623)
* Add gcd for keras.ops * Update test cases for gcd * Update tensorflow implementation by gemini review
1 parent aad4172 commit 124a258

File tree

12 files changed

+150
-0
lines changed

12 files changed

+150
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@
187187
from keras.src.ops.numpy import floor_divide as floor_divide
188188
from keras.src.ops.numpy import full as full
189189
from keras.src.ops.numpy import full_like as full_like
190+
from keras.src.ops.numpy import gcd as gcd
190191
from keras.src.ops.numpy import get_item as get_item
191192
from keras.src.ops.numpy import greater as greater
192193
from keras.src.ops.numpy import greater_equal as greater_equal

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from keras.src.ops.numpy import floor_divide as floor_divide
7676
from keras.src.ops.numpy import full as full
7777
from keras.src.ops.numpy import full_like as full_like
78+
from keras.src.ops.numpy import gcd as gcd
7879
from keras.src.ops.numpy import get_item as get_item
7980
from keras.src.ops.numpy import greater as greater
8081
from keras.src.ops.numpy import greater_equal as greater_equal

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@
187187
from keras.src.ops.numpy import floor_divide as floor_divide
188188
from keras.src.ops.numpy import full as full
189189
from keras.src.ops.numpy import full_like as full_like
190+
from keras.src.ops.numpy import gcd as gcd
190191
from keras.src.ops.numpy import get_item as get_item
191192
from keras.src.ops.numpy import greater as greater
192193
from keras.src.ops.numpy import greater_equal as greater_equal

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from keras.src.ops.numpy import floor_divide as floor_divide
7676
from keras.src.ops.numpy import full as full
7777
from keras.src.ops.numpy import full_like as full_like
78+
from keras.src.ops.numpy import gcd as gcd
7879
from keras.src.ops.numpy import get_item as get_item
7980
from keras.src.ops.numpy import greater as greater
8081
from keras.src.ops.numpy import greater_equal as greater_equal

keras/src/backend/jax/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,12 @@ def full_like(x, fill_value, dtype=None):
736736
return jnp.full_like(x, fill_value, dtype=dtype)
737737

738738

739+
def gcd(x1, x2):
740+
x1 = convert_to_tensor(x1)
741+
x2 = convert_to_tensor(x2)
742+
return jnp.gcd(x1, x2)
743+
744+
739745
def greater(x1, x2):
740746
x1 = convert_to_tensor(x1)
741747
x2 = convert_to_tensor(x2)

keras/src/backend/numpy/numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,14 @@ def full_like(x, fill_value, dtype=None):
662662
return np.full_like(x, fill_value, dtype=dtype)
663663

664664

665+
def gcd(x1, x2):
666+
x1 = convert_to_tensor(x1)
667+
x2 = convert_to_tensor(x2)
668+
669+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
670+
return np.gcd(x1, x2).astype(dtype)
671+
672+
665673
def greater(x1, x2):
666674
return np.greater(x1, x2)
667675

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ NumpyDtypeTest::test_argpartition
99
NumpyDtypeTest::test_array
1010
NumpyDtypeTest::test_bartlett
1111
NumpyDtypeTest::test_blackman
12+
NumpyDtypeTest::test_gcd
1213
NumpyDtypeTest::test_hamming
1314
NumpyDtypeTest::test_hanning
1415
NumpyDtypeTest::test_heaviside
@@ -142,6 +143,7 @@ NumpyTwoInputOpsCorrectnessTest::test_cross
142143
NumpyTwoInputOpsCorrectnessTest::test_digitize
143144
NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan
144145
NumpyTwoInputOpsCorrectnessTest::test_einsum
146+
NumpyTwoInputOpsCorrectnessTest::test_gcd
145147
NumpyTwoInputOpsCorrectnessTest::test_heaviside
146148
NumpyTwoInputOpsCorrectnessTest::test_hypot
147149
NumpyTwoInputOpsCorrectnessTest::test_inner
@@ -165,9 +167,11 @@ NumpyOneInputOpsStaticShapeTest::test_angle
165167
NumpyOneInputOpsStaticShapeTest::test_cbrt
166168
NumpyOneInputOpsStaticShapeTest::test_isneginf
167169
NumpyOneInputOpsStaticShapeTest::test_isposinf
170+
NumpyTwoInputOpsDynamicShapeTest::test_gcd
168171
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
169172
NumpyTwoInputOpsDynamicShapeTest::test_hypot
170173
NumpyTwoInputOpsDynamicShapeTest::test_isin
174+
NumpyTwoInputOpsStaticShapeTest::test_gcd
171175
NumpyTwoInputOpsStaticShapeTest::test_heaviside
172176
NumpyTwoInputOpsStaticShapeTest::test_hypot
173177
NumpyTwoInputOpsStaticShapeTest::test_isin

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,10 @@ def full_like(x, fill_value, dtype=None):
863863
return OpenVINOKerasTensor(res)
864864

865865

866+
def gcd(x1, x2):
867+
raise NotImplementedError("`gcd` is not supported with openvino backend")
868+
869+
866870
def greater(x1, x2):
867871
element_type = None
868872
if isinstance(x1, OpenVINOKerasTensor):

keras/src/backend/tensorflow/numpy.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,43 @@ def full_like(x, fill_value, dtype=None):
15311531
return tf.broadcast_to(fill_value, tf.shape(x))
15321532

15331533

1534+
def gcd(x1, x2):
1535+
x1 = tf.convert_to_tensor(x1)
1536+
x2 = tf.convert_to_tensor(x2)
1537+
1538+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
1539+
x1 = tf.cast(x1, dtype)
1540+
x2 = tf.cast(x2, dtype)
1541+
1542+
if not x1.dtype.is_integer:
1543+
raise TypeError("Arguments to gcd must be integers.")
1544+
1545+
target_shape = tf.broadcast_static_shape(x1.shape, x2.shape)
1546+
x1 = tf.broadcast_to(x1, target_shape)
1547+
x2 = tf.broadcast_to(x2, target_shape)
1548+
1549+
def cond(a, b):
1550+
return tf.reduce_any(b != 0)
1551+
1552+
def body(a, b):
1553+
b_safe = tf.where(tf.equal(b, 0), tf.ones_like(b), b)
1554+
return (
1555+
tf.where(tf.not_equal(b, 0), b, a),
1556+
tf.where(
1557+
tf.not_equal(b, 0),
1558+
tf.math.floormod(a, b_safe),
1559+
tf.zeros_like(b),
1560+
),
1561+
)
1562+
1563+
if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
1564+
x1 = tf.abs(x1)
1565+
x2 = tf.abs(x2)
1566+
1567+
gcd_val, _ = tf.while_loop(cond, body, [x1, x2])
1568+
return gcd_val
1569+
1570+
15341571
def greater(x1, x2):
15351572
x1 = convert_to_tensor(x1)
15361573
x2 = convert_to_tensor(x2)

keras/src/backend/torch/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,12 @@ def full_like(x, fill_value, dtype=None):
839839
return full(shape=x.shape, fill_value=fill_value, dtype=dtype)
840840

841841

842+
def gcd(x1, x2):
843+
x1 = convert_to_tensor(x1)
844+
x2 = convert_to_tensor(x2)
845+
return torch.gcd(x1, x2)
846+
847+
842848
def greater(x1, x2):
843849
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
844850
return torch.greater(x1, x2)

0 commit comments

Comments
 (0)