Skip to content

Commit 6856440

Browse files
committed
Implementation of Math.gcd
1 parent be0a4ea commit 6856440

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_math.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
BIG_INT = 9999992432902008176640000999999
1616
FLOAT_MAX = sys.float_info.max
1717

18+
class MyIndexable(object):
19+
def __init__(self, value):
20+
self.value = value
21+
def __index__(self):
22+
return self.value
23+
1824
""" The next three methods are needed for testing factorials
1925
"""
2026
def count_set_bits(n):
@@ -948,6 +954,54 @@ def test_factorial(self):
948954
self.assertEqual(math.factorial(30), 265252859812191058636308480000000)
949955
self.assertRaises(ValueError, math.factorial, -11.1)
950956

957+
def testGcd(self):
958+
gcd = math.gcd
959+
self.assertEqual(gcd(0, 0), 0)
960+
self.assertEqual(gcd(1, 0), 1)
961+
self.assertEqual(gcd(-1, 0), 1)
962+
self.assertEqual(gcd(0, 1), 1)
963+
self.assertEqual(gcd(0, -1), 1)
964+
self.assertEqual(gcd(7, 1), 1)
965+
self.assertEqual(gcd(7, -1), 1)
966+
self.assertEqual(gcd(-23, 15), 1)
967+
self.assertEqual(gcd(120, 84), 12)
968+
self.assertEqual(gcd(84, -120), 12)
969+
self.assertEqual(gcd(1216342683557601535506311712,
970+
436522681849110124616458784), 32)
971+
c = 652560
972+
x = 434610456570399902378880679233098819019853229470286994367836600566
973+
y = 1064502245825115327754847244914921553977
974+
a = x * c
975+
b = y * c
976+
self.assertEqual(gcd(a, b), c)
977+
self.assertEqual(gcd(b, a), c)
978+
self.assertEqual(gcd(-a, b), c)
979+
self.assertEqual(gcd(b, -a), c)
980+
self.assertEqual(gcd(a, -b), c)
981+
self.assertEqual(gcd(-b, a), c)
982+
self.assertEqual(gcd(-a, -b), c)
983+
self.assertEqual(gcd(-b, -a), c)
984+
c = 576559230871654959816130551884856912003141446781646602790216406874
985+
a = x * c
986+
b = y * c
987+
self.assertEqual(gcd(a, b), c)
988+
self.assertEqual(gcd(b, a), c)
989+
self.assertEqual(gcd(-a, b), c)
990+
self.assertEqual(gcd(b, -a), c)
991+
self.assertEqual(gcd(a, -b), c)
992+
self.assertEqual(gcd(-b, a), c)
993+
self.assertEqual(gcd(-a, -b), c)
994+
self.assertEqual(gcd(-b, -a), c)
995+
996+
self.assertRaises(TypeError, gcd, 120.0, 84)
997+
self.assertRaises(TypeError, gcd, 120, 84.0)
998+
self.assertEqual(gcd(MyIndexable(120), MyIndexable(84)), 12)
999+
1000+
# test of specializations
1001+
self.assertRaises(TypeError, gcd, 120, MyIndexable(6.0))
1002+
self.assertRaises(TypeError, gcd, 'ahoj', 1)
1003+
self.assertEqual(gcd(MyIndexable(True), MyIndexable(84)), 1)
1004+
9511005
def test_floor(self):
9521006
class TestFloor:
9531007
def __floor__(self):

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import com.oracle.graal.python.builtins.objects.ints.PInt;
4444
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
4545
import com.oracle.graal.python.nodes.PBaseNode;
46+
import com.oracle.graal.python.nodes.PGuards;
4647
import com.oracle.graal.python.nodes.SpecialMethodNames;
4748
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
4849
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
@@ -129,6 +130,50 @@ public double toDouble(Object x) {
129130
}
130131
}
131132

133+
@TypeSystemReference(PythonArithmeticTypes.class)
134+
@ImportStatic(MathGuards.class)
135+
static abstract class ConvertToIntNode extends PBaseNode {
136+
137+
@Child private LookupAndCallUnaryNode callIndexNode;
138+
139+
abstract Object execute(Object x);
140+
141+
public static ConvertToIntNode create() {
142+
return MathModuleBuiltinsFactory.ConvertToIntNodeGen.create();
143+
}
144+
145+
@Specialization
146+
public long toInt(long x) {
147+
return x;
148+
}
149+
150+
@Specialization
151+
public PInt toInt(PInt x) {
152+
return x;
153+
}
154+
155+
@Specialization
156+
public long toInt(double x) {
157+
throw raise(TypeError, "'float' object cannot be interpreted as an integer");
158+
}
159+
160+
@Specialization(guards = "!isNumber(x)")
161+
public Object toInt(Object x) {
162+
if (callIndexNode == null) {
163+
CompilerDirectives.transferToInterpreterAndInvalidate();
164+
callIndexNode = insert(LookupAndCallUnaryNode.create(SpecialMethodNames.__INDEX__));
165+
}
166+
Object result = callIndexNode.executeObject(x);
167+
if (result == PNone.NONE) {
168+
throw raise(TypeError, "'%p' object cannot be interpreted as an integer", x);
169+
}
170+
if (!PGuards.isInteger(result) && !PGuards.isPInt(result) && !(result instanceof Boolean)) {
171+
throw raise(TypeError, " __index__ returned non-int (type %p)", result);
172+
}
173+
return result;
174+
}
175+
}
176+
132177
public abstract static class MathUnaryBuiltinNode extends PythonUnaryBuiltinNode {
133178

134179
public void checkMathRangeError(boolean con) {
@@ -1003,6 +1048,79 @@ public PTuple frexpO(Object value,
10031048
}
10041049
}
10051050

1051+
@Builtin(name = "gcd", fixedNumOfArguments = 2)
1052+
@TypeSystemReference(PythonArithmeticTypes.class)
1053+
@GenerateNodeFactory
1054+
@ImportStatic(MathGuards.class)
1055+
public abstract static class GcdNode extends PythonBinaryBuiltinNode {
1056+
1057+
private long count(long a, long b) {
1058+
if (b == 0) {
1059+
return a;
1060+
}
1061+
return count(b, a % b);
1062+
}
1063+
1064+
@Specialization
1065+
long gcd(long x, long y) {
1066+
return Math.abs(count(x, y));
1067+
}
1068+
1069+
@Specialization
1070+
PInt gcd(long x, PInt y) {
1071+
return factory().createInt(BigInteger.valueOf(x).gcd(y.getValue()));
1072+
}
1073+
1074+
@Specialization
1075+
PInt gcd(PInt x, long y) {
1076+
return factory().createInt(x.getValue().gcd(BigInteger.valueOf(y)));
1077+
}
1078+
1079+
@Specialization
1080+
PInt gcd(PInt x, PInt y) {
1081+
return factory().createInt(x.getValue().gcd(y.getValue()));
1082+
}
1083+
1084+
@Specialization
1085+
int gcd(double x, double y) {
1086+
throw raise(TypeError, "'float' object cannot be interpreted as an integer");
1087+
}
1088+
1089+
@Specialization
1090+
int gcd(long x, double y) {
1091+
throw raise(TypeError, "'float' object cannot be interpreted as an integer");
1092+
}
1093+
1094+
@Specialization
1095+
int gcd(double x, long y) {
1096+
throw raise(TypeError, "'float' object cannot be interpreted as an integer");
1097+
}
1098+
1099+
@Specialization
1100+
int gcd(double x, PInt y) {
1101+
throw raise(TypeError, "'float' object cannot be interpreted as an integer");
1102+
}
1103+
1104+
@Specialization
1105+
int gcd(PInt x, double y) {
1106+
throw raise(TypeError, "'float' object cannot be interpreted as an integer");
1107+
}
1108+
1109+
@Specialization(guards = "!isNumber(x) || !isNumber(y)")
1110+
Object gcd(Object x, Object y,
1111+
@Cached("create()") ConvertToIntNode xConvert,
1112+
@Cached("create()") ConvertToIntNode yConvert,
1113+
@Cached("create()") GcdNode recursiveNode) {
1114+
Object xValue = xConvert.execute(x);
1115+
Object yValue = yConvert.execute(y);
1116+
return recursiveNode.execute(xValue, yValue);
1117+
}
1118+
1119+
public static GcdNode create() {
1120+
return MathModuleBuiltinsFactory.GcdNodeFactory.create();
1121+
}
1122+
}
1123+
10061124
@Builtin(name = "acos", fixedNumOfArguments = 1, doc = "Return the arc cosine (measured in radians) of x.")
10071125
@GenerateNodeFactory
10081126
public abstract static class AcosNode extends MathDoubleUnaryBuiltinNode {

0 commit comments

Comments
 (0)