Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.pulsar.broker.qos;

import java.math.BigInteger;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.LongAdder;
Expand Down Expand Up @@ -206,9 +207,10 @@ private long calculateNewTokensSinceLastUpdate(long currentNanos) {
long currentRatePeriodNanos = getRatePeriodNanos();
// new tokens is the amount of tokens that are created in the duration since the last update
// with the configured rate
newTokens = (durationNanos * currentRate) / currentRatePeriodNanos;
newTokens = safeMulDivFloor(durationNanos, currentRate, currentRatePeriodNanos);
// carry forward the remainder nanos so that the rounding error is eliminated
long remainderNanos = durationNanos - ((newTokens * currentRatePeriodNanos) / currentRate);
long consumedNanos = safeMulDivFloor(newTokens, currentRatePeriodNanos, currentRate);
long remainderNanos = durationNanos >= consumedNanos ? durationNanos - consumedNanos : 0;
if (remainderNanos > 0) {
REMAINDER_NANOS_UPDATER.addAndGet(this, remainderNanos);
}
Expand Down Expand Up @@ -263,13 +265,53 @@ public long calculateThrottlingDuration() {
*/
public long calculateThrottlingDuration(long requiredTokens) {
long currentTokens = consumeTokensAndMaybeUpdateTokensBalance(0);

if (currentTokens >= requiredTokens) {
return 0L;
}
// when currentTokens is negative, subtracting a negative value results in
// adding the absolute value (-(-x) -> +x)
long needTokens = requiredTokens - currentTokens;
return (needTokens * getRatePeriodNanos()) / getRate();
long needTokens;
try {
needTokens = Math.subtractExact(requiredTokens, currentTokens);
} catch (ArithmeticException e) {
needTokens = Long.MAX_VALUE;
}
return safeMulDivFloor(needTokens, getRatePeriodNanos(), getRate());
}

private static long safeMulDivFloor(long multiplicand, long multiplier, long divisor) {
if (multiplicand < 0 || multiplier < 0) {
throw new IllegalArgumentException("multiplicand and multiplier must be >= 0");
}
if (divisor <= 0) {
throw new IllegalArgumentException("divisor must be > 0");
}
if (multiplicand == 0 || multiplier == 0) {
return 0;
}
// Fast path
// Check if multiplication fits in a 64-bit value
// Math.multiplyHigh is intrinsified by the JVM (single mulq/mul instruction),
// avoiding the cost of a division-based overflow check.
// It returns the upper 64 bits of the full 128-bit multiplication result.
// When the result is 0, the product fits in 64 bits.
if (Math.multiplyHigh(multiplicand, multiplier) == 0) {
long product = multiplicand * multiplier;
if (product >= 0) {
// product fits in signed 64-bit
return product / divisor;
}
// product is in [2^63, 2^64): fits unsigned but not signed
long result = Long.divideUnsigned(product, divisor);
// cap at Long.MAX_VALUE if result itself overflows signed long
return result >= 0 ? result : Long.MAX_VALUE;
}
// Fallback to BigInteger division
BigInteger result = BigInteger.valueOf(multiplicand)
.multiply(BigInteger.valueOf(multiplier))
.divide(BigInteger.valueOf(divisor));
return result.bitLength() < Long.SIZE ? result.longValue() : Long.MAX_VALUE;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;


Expand Down Expand Up @@ -195,4 +196,52 @@ void shouldHandleEventualConsistency() {
// iteration, the tokens should be equal to the initial tokens
.isEqualTo(initialTokens);
}
}

@DataProvider(name = "largeRates")
public Object[][] largeRates() {
return new Object[][]{
{500_000_000L},
{980_000_000L},
{1_000_000_000L},
{1_500_000_000L},
{2_000_000_000L},
{Long.MAX_VALUE / 100L},
{Long.MAX_VALUE / 10L},
{Long.MAX_VALUE / 9L},
{Long.MAX_VALUE}
};
}

@Test(dataProvider = "largeRates")
void shouldRefillTokensWithoutOverflowForLargeRateAnd10sPeriod(long rate) {
long ratePeriodNanos = TimeUnit.SECONDS.toNanos(10);
asyncTokenBucket =
AsyncTokenBucket.builder()
.rate(rate)
.ratePeriodNanos(ratePeriodNanos)
.addTokensResolutionNanos(ratePeriodNanos)
.initialTokens(0)
.clock(clockSource)
.build();

incrementSeconds(10);
incrementMillis(1);

assertEquals(asyncTokenBucket.getTokens(), rate);
}

@Test
void shouldCalculateThrottlingDurationWithoutOverflowForLargeNeedTokens() {
asyncTokenBucket =
AsyncTokenBucket.builder()
.rate(1)
.ratePeriodNanos(TimeUnit.SECONDS.toNanos(10))
.initialTokens(0)
.clock(clockSource)
.build();
asyncTokenBucket.consumeTokens(1);

long throttlingDuration = asyncTokenBucket.calculateThrottlingDuration(1_000_000_000L);
assertEquals(throttlingDuration, Long.MAX_VALUE);
}
}
Loading