@@ -1531,6 +1531,43 @@ def full_like(x, fill_value, dtype=None):
1531
1531
return tf .broadcast_to (fill_value , tf .shape (x ))
1532
1532
1533
1533
1534
+ def gcd (x1 , x2 ):
1535
+ x1 = tf .convert_to_tensor (x1 )
1536
+ x2 = tf .convert_to_tensor (x2 )
1537
+
1538
+ dtype = dtypes .result_type (x1 .dtype , x2 .dtype )
1539
+ x1 = tf .cast (x1 , dtype )
1540
+ x2 = tf .cast (x2 , dtype )
1541
+
1542
+ if not x1 .dtype .is_integer :
1543
+ raise TypeError ("Arguments to gcd must be integers." )
1544
+
1545
+ target_shape = tf .broadcast_static_shape (x1 .shape , x2 .shape )
1546
+ x1 = tf .broadcast_to (x1 , target_shape )
1547
+ x2 = tf .broadcast_to (x2 , target_shape )
1548
+
1549
+ def cond (a , b ):
1550
+ return tf .reduce_any (b != 0 )
1551
+
1552
+ def body (a , b ):
1553
+ b_safe = tf .where (tf .equal (b , 0 ), tf .ones_like (b ), b )
1554
+ return (
1555
+ tf .where (tf .not_equal (b , 0 ), b , a ),
1556
+ tf .where (
1557
+ tf .not_equal (b , 0 ),
1558
+ tf .math .floormod (a , b_safe ),
1559
+ tf .zeros_like (b ),
1560
+ ),
1561
+ )
1562
+
1563
+ if dtype not in [tf .uint8 , tf .uint16 , tf .uint32 , tf .uint64 ]:
1564
+ x1 = tf .abs (x1 )
1565
+ x2 = tf .abs (x2 )
1566
+
1567
+ gcd_val , _ = tf .while_loop (cond , body , [x1 , x2 ])
1568
+ return gcd_val
1569
+
1570
+
1534
1571
def greater (x1 , x2 ):
1535
1572
x1 = convert_to_tensor (x1 )
1536
1573
x2 = convert_to_tensor (x2 )
0 commit comments