|
19 | 19 |
|
20 | 20 | package org.apache.pulsar.broker.qos; |
21 | 21 |
|
| 22 | +import java.math.BigInteger; |
22 | 23 | import java.util.concurrent.TimeUnit; |
23 | 24 | import java.util.concurrent.atomic.AtomicLongFieldUpdater; |
24 | 25 | import java.util.concurrent.atomic.LongAdder; |
@@ -206,9 +207,10 @@ private long calculateNewTokensSinceLastUpdate(long currentNanos) { |
206 | 207 | long currentRatePeriodNanos = getRatePeriodNanos(); |
207 | 208 | // new tokens is the amount of tokens that are created in the duration since the last update |
208 | 209 | // with the configured rate |
209 | | - newTokens = (durationNanos * currentRate) / currentRatePeriodNanos; |
| 210 | + newTokens = safeMulDivFloor(durationNanos, currentRate, currentRatePeriodNanos); |
210 | 211 | // carry forward the remainder nanos so that the rounding error is eliminated |
211 | | - long remainderNanos = durationNanos - ((newTokens * currentRatePeriodNanos) / currentRate); |
| 212 | + long consumedNanos = safeMulDivFloor(newTokens, currentRatePeriodNanos, currentRate); |
| 213 | + long remainderNanos = durationNanos >= consumedNanos ? durationNanos - consumedNanos : 0; |
212 | 214 | if (remainderNanos > 0) { |
213 | 215 | REMAINDER_NANOS_UPDATER.addAndGet(this, remainderNanos); |
214 | 216 | } |
@@ -263,13 +265,53 @@ public long calculateThrottlingDuration() { |
263 | 265 | */ |
264 | 266 | public long calculateThrottlingDuration(long requiredTokens) { |
265 | 267 | long currentTokens = consumeTokensAndMaybeUpdateTokensBalance(0); |
| 268 | + |
266 | 269 | if (currentTokens >= requiredTokens) { |
267 | 270 | return 0L; |
268 | 271 | } |
269 | 272 | // when currentTokens is negative, subtracting a negative value results in |
270 | 273 | // adding the absolute value (-(-x) -> +x) |
271 | | - long needTokens = requiredTokens - currentTokens; |
272 | | - return (needTokens * getRatePeriodNanos()) / getRate(); |
| 274 | + long needTokens; |
| 275 | + try { |
| 276 | + needTokens = Math.subtractExact(requiredTokens, currentTokens); |
| 277 | + } catch (ArithmeticException e) { |
| 278 | + needTokens = Long.MAX_VALUE; |
| 279 | + } |
| 280 | + return safeMulDivFloor(needTokens, getRatePeriodNanos(), getRate()); |
| 281 | + } |
| 282 | + |
| 283 | + private static long safeMulDivFloor(long multiplicand, long multiplier, long divisor) { |
| 284 | + if (multiplicand < 0 || multiplier < 0) { |
| 285 | + throw new IllegalArgumentException("multiplicand and multiplier must be >= 0"); |
| 286 | + } |
| 287 | + if (divisor <= 0) { |
| 288 | + throw new IllegalArgumentException("divisor must be > 0"); |
| 289 | + } |
| 290 | + if (multiplicand == 0 || multiplier == 0) { |
| 291 | + return 0; |
| 292 | + } |
| 293 | + // Fast path |
| 294 | + // Check if multiplication fits in a 64-bit value |
| 295 | + // Math.multiplyHigh is intrinsified by the JVM (single mulq/mul instruction), |
| 296 | + // avoiding the cost of a division-based overflow check. |
| 297 | + // It returns the upper 64 bits of the full 128-bit multiplication result. |
| 298 | + // When the result is 0, the product fits in 64 bits. |
| 299 | + if (Math.multiplyHigh(multiplicand, multiplier) == 0) { |
| 300 | + long product = multiplicand * multiplier; |
| 301 | + if (product >= 0) { |
| 302 | + // product fits in signed 64-bit |
| 303 | + return product / divisor; |
| 304 | + } |
| 305 | + // product is in [2^63, 2^64): fits unsigned but not signed |
| 306 | + long result = Long.divideUnsigned(product, divisor); |
| 307 | + // cap at Long.MAX_VALUE if result itself overflows signed long |
| 308 | + return result >= 0 ? result : Long.MAX_VALUE; |
| 309 | + } |
| 310 | + // Fallback to BigInteger division |
| 311 | + BigInteger result = BigInteger.valueOf(multiplicand) |
| 312 | + .multiply(BigInteger.valueOf(multiplier)) |
| 313 | + .divide(BigInteger.valueOf(divisor)); |
| 314 | + return result.bitLength() < Long.SIZE ? result.longValue() : Long.MAX_VALUE; |
273 | 315 | } |
274 | 316 |
|
275 | 317 | /** |
|
0 commit comments