Skip to content

Commit 1c5eb37

Browse files
fabioromano1rgiulietti
authored andcommitted
8355719: Reduce memory consumption of BigInteger.pow()
Reviewed-by: rgiulietti
1 parent 601f05e commit 1c5eb37

File tree

2 files changed

+280
-98
lines changed

2 files changed

+280
-98
lines changed

src/java.base/share/classes/java/math/BigInteger.java

Lines changed: 91 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 1996, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 1996, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -1246,6 +1246,16 @@ else if (val < 0 && val >= -MAX_CONSTANT)
12461246
return new BigInteger(val);
12471247
}
12481248

1249+
/**
1250+
* Constructs a BigInteger with magnitude specified by the long,
1251+
* which may not be zero, and the signum specified by the int.
1252+
*/
1253+
private BigInteger(long mag, int signum) {
1254+
assert mag != 0 && signum != 0;
1255+
this.signum = signum;
1256+
this.mag = toMagArray(mag);
1257+
}
1258+
12491259
/**
12501260
* Constructs a BigInteger with the specified value, which may not be zero.
12511261
*/
@@ -1256,16 +1266,14 @@ private BigInteger(long val) {
12561266
} else {
12571267
signum = 1;
12581268
}
1269+
mag = toMagArray(val);
1270+
}
12591271

1260-
int highWord = (int)(val >>> 32);
1261-
if (highWord == 0) {
1262-
mag = new int[1];
1263-
mag[0] = (int)val;
1264-
} else {
1265-
mag = new int[2];
1266-
mag[0] = highWord;
1267-
mag[1] = (int)val;
1268-
}
1272+
private static int[] toMagArray(long mag) {
1273+
int highWord = (int) (mag >>> 32);
1274+
return highWord == 0
1275+
? new int[] { (int) mag }
1276+
: new int[] { highWord, (int) mag };
12691277
}
12701278

12711279
/**
@@ -2589,116 +2597,101 @@ public BigInteger pow(int exponent) {
25892597
if (exponent < 0) {
25902598
throw new ArithmeticException("Negative exponent");
25912599
}
2592-
if (signum == 0) {
2593-
return (exponent == 0 ? ONE : this);
2594-
}
2600+
if (exponent == 0 || this.equals(ONE))
2601+
return ONE;
2602+
2603+
if (signum == 0 || exponent == 1)
2604+
return this;
25952605

2596-
BigInteger partToSquare = this.abs();
2606+
BigInteger base = this.abs();
2607+
final boolean negative = signum < 0 && (exponent & 1) == 1;
25972608

25982609
// Factor out powers of two from the base, as the exponentiation of
25992610
// these can be done by left shifts only.
26002611
// The remaining part can then be exponentiated faster. The
26012612
// powers of two will be multiplied back at the end.
2602-
int powersOfTwo = partToSquare.getLowestSetBit();
2603-
long bitsToShiftLong = (long)powersOfTwo * exponent;
2604-
if (bitsToShiftLong > Integer.MAX_VALUE) {
2613+
final int powersOfTwo = base.getLowestSetBit();
2614+
final long bitsToShiftLong = (long) powersOfTwo * exponent;
2615+
final int bitsToShift = (int) bitsToShiftLong;
2616+
if (bitsToShift != bitsToShiftLong) {
26052617
reportOverflow();
26062618
}
2607-
int bitsToShift = (int)bitsToShiftLong;
26082619

2609-
int remainingBits;
2610-
2611-
// Factor the powers of two out quickly by shifting right, if needed.
2612-
if (powersOfTwo > 0) {
2613-
partToSquare = partToSquare.shiftRight(powersOfTwo);
2614-
remainingBits = partToSquare.bitLength();
2615-
if (remainingBits == 1) { // Nothing left but +/- 1?
2616-
if (signum < 0 && (exponent&1) == 1) {
2617-
return NEGATIVE_ONE.shiftLeft(bitsToShift);
2618-
} else {
2619-
return ONE.shiftLeft(bitsToShift);
2620-
}
2621-
}
2622-
} else {
2623-
remainingBits = partToSquare.bitLength();
2624-
if (remainingBits == 1) { // Nothing left but +/- 1?
2625-
if (signum < 0 && (exponent&1) == 1) {
2626-
return NEGATIVE_ONE;
2627-
} else {
2628-
return ONE;
2629-
}
2630-
}
2631-
}
2620+
// Factor the powers of two out quickly by shifting right.
2621+
base = base.shiftRight(powersOfTwo);
2622+
final int remainingBits = base.bitLength();
2623+
if (remainingBits == 1) // Nothing left but +/- 1?
2624+
return (negative ? NEGATIVE_ONE : ONE).shiftLeft(bitsToShift);
26322625

26332626
// This is a quick way to approximate the size of the result,
26342627
// similar to doing log2[n] * exponent. This will give an upper bound
26352628
// of how big the result can be, and which algorithm to use.
2636-
long scaleFactor = (long)remainingBits * exponent;
2629+
final long scaleFactor = (long) remainingBits * exponent;
26372630

26382631
// Use slightly different algorithms for small and large operands.
2639-
// See if the result will safely fit into a long. (Largest 2^63-1)
2640-
if (partToSquare.mag.length == 1 && scaleFactor <= 62) {
2641-
// Small number algorithm. Everything fits into a long.
2642-
int newSign = (signum <0 && (exponent&1) == 1 ? -1 : 1);
2643-
long result = 1;
2644-
long baseToPow2 = partToSquare.mag[0] & LONG_MASK;
2645-
2646-
int workingExponent = exponent;
2647-
2648-
// Perform exponentiation using repeated squaring trick
2649-
while (workingExponent != 0) {
2650-
if ((workingExponent & 1) == 1) {
2651-
result = result * baseToPow2;
2652-
}
2653-
2654-
if ((workingExponent >>>= 1) != 0) {
2655-
baseToPow2 = baseToPow2 * baseToPow2;
2656-
}
2657-
}
2632+
// See if the result will safely fit into an unsigned long. (Largest 2^64-1)
2633+
if (scaleFactor <= Long.SIZE) {
2634+
// Small number algorithm. Everything fits into an unsigned long.
2635+
final int newSign = negative ? -1 : 1;
2636+
final long result = unsignedLongPow(base.mag[0] & LONG_MASK, exponent);
26582637

26592638
// Multiply back the powers of two (quickly, by shifting left)
2660-
if (powersOfTwo > 0) {
2661-
if (bitsToShift + scaleFactor <= 62) { // Fits in long?
2662-
return valueOf((result << bitsToShift) * newSign);
2663-
} else {
2664-
return valueOf(result*newSign).shiftLeft(bitsToShift);
2665-
}
2666-
} else {
2667-
return valueOf(result*newSign);
2668-
}
2669-
} else {
2670-
if ((long)bitLength() * exponent / Integer.SIZE > MAX_MAG_LENGTH) {
2671-
reportOverflow();
2672-
}
2639+
return bitsToShift + scaleFactor <= Long.SIZE // Fits in long?
2640+
? new BigInteger(result << bitsToShift, newSign)
2641+
: new BigInteger(result, newSign).shiftLeft(bitsToShift);
2642+
}
26732643

2674-
// Large number algorithm. This is basically identical to
2675-
// the algorithm above, but calls multiply() and square()
2676-
// which may use more efficient algorithms for large numbers.
2677-
BigInteger answer = ONE;
2644+
if ((bitLength() - 1L) * exponent >= Integer.MAX_VALUE) {
2645+
reportOverflow();
2646+
}
26782647

2679-
int workingExponent = exponent;
2680-
// Perform exponentiation using repeated squaring trick
2681-
while (workingExponent != 0) {
2682-
if ((workingExponent & 1) == 1) {
2683-
answer = answer.multiply(partToSquare);
2684-
}
2648+
// Large number algorithm. This is basically identical to
2649+
// the algorithm above, but calls multiply()
2650+
// which may use more efficient algorithms for large numbers.
2651+
BigInteger answer = ONE;
26852652

2686-
if ((workingExponent >>>= 1) != 0) {
2687-
partToSquare = partToSquare.square();
2688-
}
2689-
}
2690-
// Multiply back the (exponentiated) powers of two (quickly,
2691-
// by shifting left)
2692-
if (powersOfTwo > 0) {
2693-
answer = answer.shiftLeft(bitsToShift);
2694-
}
2653+
final int expZeros = Integer.numberOfLeadingZeros(exponent);
2654+
int workingExp = exponent << expZeros;
2655+
// Perform exponentiation using repeated squaring trick
2656+
// The loop relies on this invariant:
2657+
// base^exponent == answer^(2^expLen) * base^(workingExp >>> (32-expLen))
2658+
for (int expLen = Integer.SIZE - expZeros; expLen > 0; expLen--) {
2659+
answer = answer.multiply(answer);
2660+
if (workingExp < 0) // leading bit is set
2661+
answer = answer.multiply(base);
26952662

2696-
if (signum < 0 && (exponent&1) == 1) {
2697-
return answer.negate();
2698-
} else {
2699-
return answer;
2700-
}
2663+
workingExp <<= 1;
2664+
}
2665+
2666+
// Multiply back the (exponentiated) powers of two (quickly,
2667+
// by shifting left)
2668+
answer = answer.shiftLeft(bitsToShift);
2669+
return negative ? answer.negate() : answer;
2670+
}
2671+
2672+
/**
2673+
* Computes {@code x^n} using repeated squaring trick.
2674+
* Assumes {@code x != 0 && x^n < 2^Long.SIZE}.
2675+
*/
2676+
static long unsignedLongPow(long x, int n) {
2677+
if (x == 1L || n == 0)
2678+
return 1L;
2679+
2680+
if (x == 2L)
2681+
return 1L << n;
2682+
2683+
/*
2684+
* The method assumption means that n <= 40 here.
2685+
* Thus, the loop body executes at most 5 times.
2686+
*/
2687+
long pow = 1L;
2688+
for (; n != 1; n >>>= 1) {
2689+
if ((n & 1) != 0)
2690+
pow *= x;
2691+
2692+
x *= x;
27012693
}
2694+
return pow * x;
27022695
}
27032696

27042697
/**

0 commit comments

Comments
 (0)