Skip to content

Commit dc03c36

Browse files
committed
[GR-20982] Improve specializations with overflow handling for integer add, sub, and mul
PullRequest: graalpython/805
2 parents 48a6f25 + 2a08117 commit dc03c36

File tree

3 files changed

+83
-48
lines changed

3 files changed

+83
-48
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/IntBuiltins.java

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,24 @@ long addLong(long left, long right) {
163163
}
164164

165165
@Specialization
166-
PInt addPInt(long left, long right) {
167-
return factory().createInt(op(BigInteger.valueOf(left), BigInteger.valueOf(right)));
166+
Object addLongWithOverflow(long x, long y) {
167+
/* Inlined version of Math.addExact(x, y) with BigInteger fallback. */
168+
long r = x + y;
169+
// HD 2-12 Overflow iff both arguments have the opposite sign of the result
170+
if (((x ^ r) & (y ^ r)) < 0) {
171+
return factory().createInt(op(PInt.longToBigInteger(x), PInt.longToBigInteger(y)));
172+
}
173+
return r;
168174
}
169175

170176
@Specialization
171177
PInt add(PInt left, long right) {
172-
return add(left, factory().createInt(right));
178+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
173179
}
174180

175181
@Specialization
176182
PInt add(long left, PInt right) {
177-
return add(factory().createInt(left), right);
183+
return factory().createInt(op(PInt.longToBigInteger(left), right.getValue()));
178184
}
179185

180186
@Specialization
@@ -220,18 +226,25 @@ long doLL(long x, long y) throws ArithmeticException {
220226
}
221227

222228
@Specialization
223-
PInt doLLOvf(long x, long y) {
224-
return factory().createInt(op(BigInteger.valueOf(x), BigInteger.valueOf(y)));
229+
Object doLongWithOverflow(long x, long y) {
230+
/* Inlined version of Math.subtractExact(x, y) with BigInteger fallback. */
231+
long r = x - y;
232+
// HD 2-12 Overflow iff the arguments have different signs and
233+
// the sign of the result is different than the sign of x
234+
if (((x ^ y) & (x ^ r)) < 0) {
235+
return factory().createInt(op(PInt.longToBigInteger(x), PInt.longToBigInteger(y)));
236+
}
237+
return r;
225238
}
226239

227240
@Specialization
228241
PInt doPIntLong(PInt left, long right) {
229-
return doPIntPInt(left, factory().createInt(right));
242+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
230243
}
231244

232245
@Specialization
233246
PInt doLongPInt(long left, PInt right) {
234-
return doPIntPInt(factory().createInt(left), right);
247+
return factory().createInt(op(PInt.longToBigInteger(left), right.getValue()));
235248
}
236249

237250
@Specialization
@@ -272,18 +285,25 @@ long doLL(long y, long x) throws ArithmeticException {
272285
}
273286

274287
@Specialization
275-
PInt doLLOvf(long y, long x) {
276-
return factory().createInt(op(BigInteger.valueOf(x), BigInteger.valueOf(y)));
288+
Object doLongWithOverflow(long y, long x) {
289+
/* Inlined version of Math.subtractExact(x, y) with BigInteger fallback. */
290+
long r = x - y;
291+
// HD 2-12 Overflow iff the arguments have different signs and
292+
// the sign of the result is different than the sign of x
293+
if (((x ^ y) & (x ^ r)) < 0) {
294+
return factory().createInt(op(PInt.longToBigInteger(x), PInt.longToBigInteger(y)));
295+
}
296+
return r;
277297
}
278298

279299
@Specialization
280300
PInt doPIntLong(PInt right, long left) {
281-
return doPIntPInt(factory().createInt(left), right);
301+
return factory().createInt(op(PInt.longToBigInteger(left), right.getValue()));
282302
}
283303

284304
@Specialization
285305
PInt doLongPInt(long right, PInt left) {
286-
return doPIntPInt(factory().createInt(right), left);
306+
return factory().createInt(op(PInt.longToBigInteger(right), left.getValue()));
287307
}
288308

289309
@Specialization
@@ -327,15 +347,18 @@ public abstract static class TrueDivNode extends PythonBinaryBuiltinNode {
327347

328348
@Specialization
329349
double doPI(long left, PInt right) {
330-
return doPP(factory().createInt(left), right);
350+
if (right.isZero()) {
351+
throw raise(PythonErrorType.ZeroDivisionError, "division by zero");
352+
}
353+
return op(PInt.longToBigInteger(left), right.getValue());
331354
}
332355

333356
@Specialization
334357
double doPL(PInt left, long right) {
335358
if (right == 0) {
336359
throw raise(PythonErrorType.ZeroDivisionError, "division by zero");
337360
}
338-
return doPP(left, factory().createInt(right));
361+
return op(left.getValue(), PInt.longToBigInteger(right));
339362
}
340363

341364
@Specialization
@@ -391,7 +414,10 @@ public abstract static class RTrueDivNode extends PythonBinaryBuiltinNode {
391414

392415
@Specialization
393416
double doPL(PInt right, long left) {
394-
return doPP(right, factory().createInt(left));
417+
if (right.isZero()) {
418+
throw raise(PythonErrorType.ZeroDivisionError, "division by zero");
419+
}
420+
return op(PInt.longToBigInteger(left), right.getValue());
395421
}
396422

397423
@Specialization
@@ -477,13 +503,13 @@ long doLPiOvf(long left, PInt right) {
477503
@Specialization
478504
PInt doPiL(PInt left, int right) {
479505
raiseDivisionByZero(right == 0);
480-
return factory().createInt(op(left.getValue(), BigInteger.valueOf(right)));
506+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
481507
}
482508

483509
@Specialization
484510
PInt doPiL(PInt left, long right) {
485511
raiseDivisionByZero(right == 0);
486-
return factory().createInt(op(left.getValue(), BigInteger.valueOf(right)));
512+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
487513
}
488514

489515
@Specialization
@@ -530,13 +556,13 @@ long doLL(long right, long left) {
530556
@Specialization
531557
PInt doPiL(PInt right, long left) {
532558
raiseDivisionByZero(right.isZero());
533-
return factory().createInt(op(BigInteger.valueOf(left), right.getValue()));
559+
return factory().createInt(op(PInt.longToBigInteger(left), right.getValue()));
534560
}
535561

536562
@Specialization
537563
PInt doLPi(long right, PInt left) {
538564
raiseDivisionByZero(right == 0);
539-
return factory().createInt(op(left.getValue(), BigInteger.valueOf(right)));
565+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
540566
}
541567

542568
@Specialization
@@ -576,13 +602,13 @@ long doLL(long left, long right) {
576602
@Specialization
577603
PInt doLPi(long left, PInt right) {
578604
raiseDivisionByZero(right.isZero());
579-
return factory().createInt(op(BigInteger.valueOf(left), right.getValue()));
605+
return factory().createInt(op(PInt.longToBigInteger(left), right.getValue()));
580606
}
581607

582608
@Specialization(guards = "right >= 0")
583609
PInt doPiL(PInt left, long right) {
584610
raiseDivisionByZero(right == 0);
585-
return factory().createInt(op(left.getValue(), BigInteger.valueOf(right)));
611+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
586612
}
587613

588614
@Specialization(guards = "right.isZeroOrPositive()")
@@ -593,7 +619,7 @@ PInt doPiPi(PInt left, PInt right) {
593619

594620
@Specialization(guards = "right < 0")
595621
PInt doPiLNeg(PInt left, long right) {
596-
return factory().createInt(opNeg(left.getValue(), BigInteger.valueOf(right)));
622+
return factory().createInt(opNeg(left.getValue(), PInt.longToBigInteger(right)));
597623
}
598624

599625
@Specialization(guards = "!right.isZeroOrPositive()")
@@ -642,17 +668,21 @@ long doLL(long x, long y) {
642668
}
643669

644670
@Specialization
645-
PInt doLLOvf(long x, long y) {
671+
Object doLongWithOverflow(long x, long y) {
672+
/* Inlined version of Math.multiplyExact(x, y) with BigInteger fallback. */
646673
long r = x * y;
647674
long ax = Math.abs(x);
648675
long ay = Math.abs(y);
649676
if (((ax | ay) >>> 31 != 0)) {
650-
int leadingZeros = Long.numberOfLeadingZeros(ax) + Long.numberOfLeadingZeros(ay);
651-
if (leadingZeros < 66) {
652-
return factory().createInt(mul(BigInteger.valueOf(x), BigInteger.valueOf(y)));
677+
// Some bits greater than 2^31 that might cause overflow
678+
// Check the result using the divide operator
679+
// and check for the special case of Long.MIN_VALUE * -1
680+
if (((y != 0) && (r / y != x)) ||
681+
(x == Long.MIN_VALUE && y == -1)) {
682+
return factory().createInt(mul(PInt.longToBigInteger(x), PInt.longToBigInteger(y)));
653683
}
654684
}
655-
return factory().createInt(r);
685+
return r;
656686
}
657687

658688
@Specialization(guards = "right == 0")
@@ -668,7 +698,7 @@ PInt doPIntLongOne(PInt left, @SuppressWarnings("unused") long right) {
668698

669699
@Specialization(guards = {"right != 0", "right != 1"})
670700
PInt doPIntLong(PInt left, long right) {
671-
return factory().createInt(mul(left.getValue(), BigInteger.valueOf(right)));
701+
return factory().createInt(mul(left.getValue(), PInt.longToBigInteger(right)));
672702
}
673703

674704
@Specialization
@@ -733,7 +763,7 @@ int doIntegerFast(int left, int right, @SuppressWarnings("unused") PNone none) {
733763

734764
@Specialization(guards = "right >= 0")
735765
PInt doInteger(int left, int right, @SuppressWarnings("unused") PNone none) {
736-
return factory().createInt(op(BigInteger.valueOf(left), right));
766+
return factory().createInt(op(PInt.longToBigInteger(left), right));
737767
}
738768

739769
@Specialization(guards = "right >= 0", rewriteOn = ArithmeticException.class)
@@ -773,7 +803,7 @@ long doLongFast(long left, long right, @SuppressWarnings("unused") PNone none) {
773803

774804
@Specialization(guards = "right >= 0")
775805
PInt doLong(long left, long right, @SuppressWarnings("unused") PNone none) {
776-
return factory().createInt(op(BigInteger.valueOf(left), right));
806+
return factory().createInt(op(PInt.longToBigInteger(left), right));
777807
}
778808

779809
@Specialization
@@ -839,7 +869,7 @@ private BigInteger op(BigInteger a, long b) {
839869
} else if (value == 1) {
840870
return BigInteger.ONE;
841871
} else if (value == -1) {
842-
return (b & 1) != 0 ? BigInteger.valueOf(-1) : BigInteger.ONE;
872+
return (b & 1) != 0 ? PInt.longToBigInteger(-1) : BigInteger.ONE;
843873
}
844874
} catch (ArithmeticException e) {
845875
// fall through to normal computation
@@ -888,9 +918,9 @@ long pos(long arg) {
888918
PInt posOvf(long arg) throws IllegalArgumentException {
889919
long result = Math.abs(arg);
890920
if (result < 0) {
891-
return factory().createInt(op(BigInteger.valueOf(arg)));
921+
return factory().createInt(op(PInt.longToBigInteger(arg)));
892922
} else {
893-
return factory().createInt(BigInteger.valueOf(arg));
923+
return factory().createInt(PInt.longToBigInteger(arg));
894924
}
895925
}
896926

@@ -986,7 +1016,7 @@ long neg(long arg) {
9861016

9871017
@Specialization
9881018
PInt negOvf(long arg) {
989-
BigInteger value = arg == Long.MIN_VALUE ? negate(BigInteger.valueOf(arg)) : BigInteger.valueOf(-arg);
1019+
BigInteger value = arg == Long.MIN_VALUE ? negate(PInt.longToBigInteger(arg)) : PInt.longToBigInteger(-arg);
9901020
return factory().createInt(value);
9911021
}
9921022

@@ -1084,7 +1114,7 @@ Object doIIOvf(int left, int right) {
10841114
try {
10851115
return leftShiftExact(left, right);
10861116
} catch (ArithmeticException e) {
1087-
return doGuardedBiI(BigInteger.valueOf(left), right);
1117+
return doGuardedBiI(PInt.longToBigInteger(left), right);
10881118
}
10891119
}
10901120

@@ -1102,7 +1132,7 @@ Object doLLOvf(long left, long right) {
11021132
} catch (ArithmeticException e) {
11031133
int rightI = (int) right;
11041134
if (rightI == right) {
1105-
return factory().createInt(op(BigInteger.valueOf(left), rightI));
1135+
return factory().createInt(op(PInt.longToBigInteger(left), rightI));
11061136
} else {
11071137
throw raise(PythonErrorType.OverflowError);
11081138
}
@@ -1113,7 +1143,7 @@ Object doLLOvf(long left, long right) {
11131143
PInt doLPi(long left, PInt right) {
11141144
raiseNegativeShiftCount(!right.isZeroOrPositive());
11151145
try {
1116-
return factory().createInt(op(BigInteger.valueOf(left), right.intValue()));
1146+
return factory().createInt(op(PInt.longToBigInteger(left), right.intValue()));
11171147
} catch (ArithmeticException e) {
11181148
throw raise(PythonErrorType.OverflowError);
11191149
}
@@ -1191,13 +1221,13 @@ long doLL(long left, long right) {
11911221
@Specialization
11921222
PInt doIPi(int left, PInt right) {
11931223
raiseNegativeShiftCount(!right.isZeroOrPositive());
1194-
return factory().createInt(op(BigInteger.valueOf(left), right.intValue()));
1224+
return factory().createInt(op(PInt.longToBigInteger(left), right.intValue()));
11951225
}
11961226

11971227
@Specialization
11981228
PInt doLPi(long left, PInt right) {
11991229
raiseNegativeShiftCount(!right.isZeroOrPositive());
1200-
return factory().createInt(op(BigInteger.valueOf(left), right.intValue()));
1230+
return factory().createInt(op(PInt.longToBigInteger(left), right.intValue()));
12011231
}
12021232

12031233
@Specialization
@@ -1266,12 +1296,12 @@ long doInteger(long left, long right) {
12661296

12671297
@Specialization
12681298
PInt doPInt(long left, PInt right) {
1269-
return factory().createInt(op(BigInteger.valueOf(left), right.getValue()));
1299+
return factory().createInt(op(PInt.longToBigInteger(left), right.getValue()));
12701300
}
12711301

12721302
@Specialization
12731303
PInt doPInt(PInt left, long right) {
1274-
return factory().createInt(op(left.getValue(), BigInteger.valueOf(right)));
1304+
return factory().createInt(op(left.getValue(), PInt.longToBigInteger(right)));
12751305
}
12761306

12771307
@Specialization
@@ -1449,10 +1479,10 @@ boolean eqVoidPtrPInt(PythonNativeVoidPtr a, PInt b,
14491479
long ptrVal = lib.asPointer(a.object);
14501480
if (ptrVal < 0) {
14511481
// pointers are considered unsigned
1452-
BigInteger bi = BigInteger.valueOf(ptrVal).add(BigInteger.ONE.shiftLeft(64));
1482+
BigInteger bi = PInt.longToBigInteger(ptrVal).add(BigInteger.ONE.shiftLeft(64));
14531483
return bi.equals(b.getValue());
14541484
}
1455-
return BigInteger.valueOf(ptrVal).equals(b.getValue());
1485+
return PInt.longToBigInteger(ptrVal).equals(b.getValue());
14561486
} catch (UnsupportedMessageException e) {
14571487
// fall through
14581488
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/PInt.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ public int compareTo(long i) {
223223

224224
@TruffleBoundary
225225
private static final int compareTo(BigInteger left, long right) {
226-
return left.compareTo(BigInteger.valueOf(right));
226+
return left.compareTo(longToBigInteger(right));
227227
}
228228

229229
@Override
@@ -236,6 +236,11 @@ private static final String toString(BigInteger value) {
236236
return value.toString();
237237
}
238238

239+
@TruffleBoundary
240+
public static final BigInteger longToBigInteger(long value) {
241+
return BigInteger.valueOf(value);
242+
}
243+
239244
public double doubleValue() {
240245
return doubleValue(value);
241246
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/object/PythonObjectFactory.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,23 +220,23 @@ public SuperObject createSuperObject(LazyPythonClass self) {
220220
* Primitive types
221221
*/
222222
public PInt createInt(int value) {
223-
return trace(new PInt(PythonBuiltinClassType.PInt, BigInteger.valueOf(value)));
223+
return trace(new PInt(PythonBuiltinClassType.PInt, PInt.longToBigInteger(value)));
224224
}
225225

226226
public PInt createInt(long value) {
227-
return trace(new PInt(PythonBuiltinClassType.PInt, BigInteger.valueOf(value)));
227+
return trace(new PInt(PythonBuiltinClassType.PInt, PInt.longToBigInteger(value)));
228228
}
229229

230230
public PInt createInt(BigInteger value) {
231231
return trace(new PInt(PythonBuiltinClassType.PInt, value));
232232
}
233233

234234
public Object createInt(LazyPythonClass cls, int value) {
235-
return trace(new PInt(cls, BigInteger.valueOf(value)));
235+
return trace(new PInt(cls, PInt.longToBigInteger(value)));
236236
}
237237

238238
public Object createInt(LazyPythonClass cls, long value) {
239-
return trace(new PInt(cls, BigInteger.valueOf(value)));
239+
return trace(new PInt(cls, PInt.longToBigInteger(value)));
240240
}
241241

242242
public PInt createInt(LazyPythonClass cls, BigInteger value) {

0 commit comments

Comments
 (0)