Skip to content

Commit d207d9c

Browse files
[fix][broker] Guard AsyncTokenBucket against long overflow (#25262)
Co-authored-by: Lari Hotari <lhotari@apache.org>
1 parent 82846e5 commit d207d9c

File tree

2 files changed

+96
-5
lines changed

2 files changed

+96
-5
lines changed

pulsar-broker/src/main/java/org/apache/pulsar/broker/qos/AsyncTokenBucket.java

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.pulsar.broker.qos;
2121

22+
import java.math.BigInteger;
2223
import java.util.concurrent.TimeUnit;
2324
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
2425
import java.util.concurrent.atomic.LongAdder;
@@ -206,9 +207,10 @@ private long calculateNewTokensSinceLastUpdate(long currentNanos) {
206207
long currentRatePeriodNanos = getRatePeriodNanos();
207208
// new tokens is the amount of tokens that are created in the duration since the last update
208209
// with the configured rate
209-
newTokens = (durationNanos * currentRate) / currentRatePeriodNanos;
210+
newTokens = safeMulDivFloor(durationNanos, currentRate, currentRatePeriodNanos);
210211
// carry forward the remainder nanos so that the rounding error is eliminated
211-
long remainderNanos = durationNanos - ((newTokens * currentRatePeriodNanos) / currentRate);
212+
long consumedNanos = safeMulDivFloor(newTokens, currentRatePeriodNanos, currentRate);
213+
long remainderNanos = durationNanos >= consumedNanos ? durationNanos - consumedNanos : 0;
212214
if (remainderNanos > 0) {
213215
REMAINDER_NANOS_UPDATER.addAndGet(this, remainderNanos);
214216
}
@@ -263,13 +265,53 @@ public long calculateThrottlingDuration() {
263265
*/
264266
public long calculateThrottlingDuration(long requiredTokens) {
265267
long currentTokens = consumeTokensAndMaybeUpdateTokensBalance(0);
268+
266269
if (currentTokens >= requiredTokens) {
267270
return 0L;
268271
}
269272
// when currentTokens is negative, subtracting a negative value results in
270273
// adding the absolute value (-(-x) -> +x)
271-
long needTokens = requiredTokens - currentTokens;
272-
return (needTokens * getRatePeriodNanos()) / getRate();
274+
long needTokens;
275+
try {
276+
needTokens = Math.subtractExact(requiredTokens, currentTokens);
277+
} catch (ArithmeticException e) {
278+
needTokens = Long.MAX_VALUE;
279+
}
280+
return safeMulDivFloor(needTokens, getRatePeriodNanos(), getRate());
281+
}
282+
283+
private static long safeMulDivFloor(long multiplicand, long multiplier, long divisor) {
284+
if (multiplicand < 0 || multiplier < 0) {
285+
throw new IllegalArgumentException("multiplicand and multiplier must be >= 0");
286+
}
287+
if (divisor <= 0) {
288+
throw new IllegalArgumentException("divisor must be > 0");
289+
}
290+
if (multiplicand == 0 || multiplier == 0) {
291+
return 0;
292+
}
293+
// Fast path
294+
// Check if multiplication fits in a 64-bit value
295+
// Math.multiplyHigh is intrinsified by the JVM (single mulq/mul instruction),
296+
// avoiding the cost of a division-based overflow check.
297+
// It returns the upper 64 bits of the full 128-bit multiplication result.
298+
// When the result is 0, the product fits in 64 bits.
299+
if (Math.multiplyHigh(multiplicand, multiplier) == 0) {
300+
long product = multiplicand * multiplier;
301+
if (product >= 0) {
302+
// product fits in signed 64-bit
303+
return product / divisor;
304+
}
305+
// product is in [2^63, 2^64): fits unsigned but not signed
306+
long result = Long.divideUnsigned(product, divisor);
307+
// cap at Long.MAX_VALUE if result itself overflows signed long
308+
return result >= 0 ? result : Long.MAX_VALUE;
309+
}
310+
// Fallback to BigInteger division
311+
BigInteger result = BigInteger.valueOf(multiplicand)
312+
.multiply(BigInteger.valueOf(multiplier))
313+
.divide(BigInteger.valueOf(divisor));
314+
return result.bitLength() < Long.SIZE ? result.longValue() : Long.MAX_VALUE;
273315
}
274316

275317
/**

pulsar-broker/src/test/java/org/apache/pulsar/broker/qos/AsyncTokenBucketTest.java

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.concurrent.TimeUnit;
2525
import java.util.concurrent.atomic.AtomicLong;
2626
import org.testng.annotations.BeforeMethod;
27+
import org.testng.annotations.DataProvider;
2728
import org.testng.annotations.Test;
2829

2930

@@ -195,4 +196,52 @@ void shouldHandleEventualConsistency() {
195196
// iteration, the tokens should be equal to the initial tokens
196197
.isEqualTo(initialTokens);
197198
}
198-
}
199+
200+
@DataProvider(name = "largeRates")
201+
public Object[][] largeRates() {
202+
return new Object[][]{
203+
{500_000_000L},
204+
{980_000_000L},
205+
{1_000_000_000L},
206+
{1_500_000_000L},
207+
{2_000_000_000L},
208+
{Long.MAX_VALUE / 100L},
209+
{Long.MAX_VALUE / 10L},
210+
{Long.MAX_VALUE / 9L},
211+
{Long.MAX_VALUE}
212+
};
213+
}
214+
215+
@Test(dataProvider = "largeRates")
216+
void shouldRefillTokensWithoutOverflowForLargeRateAnd10sPeriod(long rate) {
217+
long ratePeriodNanos = TimeUnit.SECONDS.toNanos(10);
218+
asyncTokenBucket =
219+
AsyncTokenBucket.builder()
220+
.rate(rate)
221+
.ratePeriodNanos(ratePeriodNanos)
222+
.addTokensResolutionNanos(ratePeriodNanos)
223+
.initialTokens(0)
224+
.clock(clockSource)
225+
.build();
226+
227+
incrementSeconds(10);
228+
incrementMillis(1);
229+
230+
assertEquals(asyncTokenBucket.getTokens(), rate);
231+
}
232+
233+
@Test
234+
void shouldCalculateThrottlingDurationWithoutOverflowForLargeNeedTokens() {
235+
asyncTokenBucket =
236+
AsyncTokenBucket.builder()
237+
.rate(1)
238+
.ratePeriodNanos(TimeUnit.SECONDS.toNanos(10))
239+
.initialTokens(0)
240+
.clock(clockSource)
241+
.build();
242+
asyncTokenBucket.consumeTokens(1);
243+
244+
long throttlingDuration = asyncTokenBucket.calculateThrottlingDuration(1_000_000_000L);
245+
assertEquals(throttlingDuration, Long.MAX_VALUE);
246+
}
247+
}

0 commit comments

Comments
 (0)