Skip to content

Commit b9dcb93

Browse files
committed
Implementation of math.isqrt
1 parent 0da4104 commit b9dcb93

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2548,4 +2548,77 @@ public double count(double x) {
25482548
}
25492549

25502550
}
2551+
2552+
@Builtin(name = "isqrt", minNumOfPositionalArgs = 1)
2553+
@TypeSystemReference(PythonArithmeticTypes.class)
2554+
@GenerateNodeFactory
2555+
@ImportStatic(MathGuards.class)
2556+
public abstract static class IsqrtNode extends PythonUnaryBuiltinNode {
2557+
2558+
@Specialization
2559+
Object isqrtLong(long x) {
2560+
raiseIfNegative(x < 0);
2561+
return makeInt(op(PInt.longToBigInteger(x)));
2562+
}
2563+
2564+
@Specialization
2565+
Object isqrtPInt(PInt x) {
2566+
raiseIfNegative(x.isNegative());
2567+
return makeInt(op(x.getValue()));
2568+
}
2569+
2570+
@Specialization(guards = "!isInteger(x)")
2571+
Object doGeneral(VirtualFrame frame, Object x,
2572+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
2573+
@CachedLibrary(limit = "1") PythonObjectLibrary lib,
2574+
@Cached IsqrtNode recursiveNode) {
2575+
return recursiveNode.execute(frame, lib.asIndexWithFrame(x, hasFrame, frame));
2576+
}
2577+
2578+
private Object makeInt(BigInteger i) {
2579+
try {
2580+
return PInt.intValueExact(i);
2581+
} catch (ArithmeticException e) {
2582+
// does not fit int, so try long
2583+
}
2584+
try {
2585+
return PInt.longValueExact(i);
2586+
} catch (ArithmeticException e) {
2587+
// does not fit long either, create PInt
2588+
}
2589+
return factory().createInt(i);
2590+
}
2591+
2592+
@TruffleBoundary
2593+
private BigInteger op(BigInteger x) {
2594+
// assumes x >= 0
2595+
if (x.equals(BigInteger.ZERO) || x.equals(BigInteger.ONE)) {
2596+
return x;
2597+
}
2598+
BigInteger start = BigInteger.ONE;
2599+
BigInteger end = x;
2600+
BigInteger result = BigInteger.ZERO;
2601+
BigInteger two = BigInteger.valueOf(2);
2602+
while (start.compareTo(end) <= 0) {
2603+
BigInteger mid = (start.add(end).divide(two));
2604+
int cmp = mid.multiply(mid).compareTo(x);
2605+
if (cmp == 0) {
2606+
return mid;
2607+
}
2608+
if (cmp < 0) {
2609+
start = mid.add(BigInteger.ONE);
2610+
result = mid;
2611+
} else {
2612+
end = mid.subtract(BigInteger.ONE);
2613+
}
2614+
}
2615+
return result;
2616+
}
2617+
2618+
private void raiseIfNegative(boolean condition) {
2619+
if (condition) {
2620+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE, "isqrt() argument");
2621+
}
2622+
}
2623+
}
25512624
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/PInt.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ public int intValueExact() {
355355
}
356356

357357
@TruffleBoundary
358-
private static int intValueExact(BigInteger value) {
358+
public static int intValueExact(BigInteger value) {
359359
return value.intValueExact();
360360
}
361361

@@ -373,7 +373,7 @@ public long longValueExact() throws ArithmeticException {
373373
}
374374

375375
@TruffleBoundary
376-
static long longValueExact(BigInteger value) throws ArithmeticException {
376+
public static long longValueExact(BigInteger value) throws ArithmeticException {
377377
return value.longValueExact();
378378
}
379379

0 commit comments

Comments
 (0)