Skip to content

Commit 66a9c9e

Browse files
authored
fix: Fix the race condition in decay average (#850)
* fix: Fix the race condition in decay average * fix format * fix * remove initial condition * update * code review * update * use clock and don't decay mean * merge getDecay and getWeight * update * update
1 parent 32284d2 commit 66a9c9e

File tree

2 files changed

+78
-71
lines changed

2 files changed

+78
-71
lines changed

google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStats.java

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,48 +15,50 @@
1515
*/
1616
package com.google.cloud.bigtable.data.v2.stub;
1717

18+
import com.google.api.core.ApiClock;
19+
import com.google.api.core.InternalApi;
20+
import com.google.api.core.NanoClock;
1821
import com.google.api.gax.batching.FlowController;
19-
import com.google.common.annotations.VisibleForTesting;
20-
import com.google.common.base.Preconditions;
2122
import java.util.concurrent.TimeUnit;
2223
import java.util.concurrent.atomic.AtomicLong;
2324

2425
/**
2526
* Records stats used in dynamic flow control, the decaying average of recorded latencies and the
2627
* last timestamp when the thresholds in {@link FlowController} are updated.
28+
*
29+
* <pre>Exponential decaying average = weightedSum / weightedCount, where
30+
* weightedSum(n) = weight(n) * value(n) + weightedSum(n - 1)
31+
* weightedCount(n) = weight(n) + weightedCount(n - 1),
32+
* and weight(n) grows exponentially over elapsed time. Biased to the past 5 minutes.
2733
*/
2834
final class DynamicFlowControlStats {
2935

30-
private static final double DEFAULT_DECAY_CONSTANT = 0.015; // Biased to the past 5 minutes
36+
// Biased to the past 5 minutes (300 seconds), e^(-decay_constant * 300) = 0.01, decay_constant ~=
37+
// 0.015
38+
private static final double DEFAULT_DECAY_CONSTANT = 0.015;
39+
// Update decay cycle start time every 15 minutes so the values won't be infinite
40+
private static final long DECAY_CYCLE_SECOND = TimeUnit.MINUTES.toSeconds(15);
3141

32-
private AtomicLong lastAdjustedTimestampMs;
33-
private DecayingAverage meanLatency;
42+
private final AtomicLong lastAdjustedTimestampMs;
43+
private final DecayingAverage meanLatency;
3444

3545
DynamicFlowControlStats() {
36-
this(DEFAULT_DECAY_CONSTANT);
46+
this(DEFAULT_DECAY_CONSTANT, NanoClock.getDefaultClock());
3747
}
3848

39-
DynamicFlowControlStats(double decayConstant) {
49+
@InternalApi("visible for testing")
50+
DynamicFlowControlStats(double decayConstant, ApiClock clock) {
4051
this.lastAdjustedTimestampMs = new AtomicLong(0);
41-
this.meanLatency = new DecayingAverage(decayConstant);
52+
this.meanLatency = new DecayingAverage(decayConstant, clock);
4253
}
4354

4455
void updateLatency(long latency) {
45-
updateLatency(latency, System.currentTimeMillis());
46-
}
47-
48-
@VisibleForTesting
49-
void updateLatency(long latency, long timestampMs) {
50-
meanLatency.update(latency, timestampMs);
56+
meanLatency.update(latency);
5157
}
5258

59+
/** Return the mean calculated from the last update, will not decay over time. */
5360
double getMeanLatency() {
54-
return getMeanLatency(System.currentTimeMillis());
55-
}
56-
57-
@VisibleForTesting
58-
double getMeanLatency(long timestampMs) {
59-
return meanLatency.getMean(timestampMs);
61+
return meanLatency.getMean();
6062
}
6163

6264
public long getLastAdjustedTimestampMs() {
@@ -71,46 +73,45 @@ private class DecayingAverage {
7173
private double decayConstant;
7274
private double mean;
7375
private double weightedCount;
74-
private AtomicLong lastUpdateTimeInSecond;
76+
private long decayCycleStartEpoch;
77+
private final ApiClock clock;
7578

76-
DecayingAverage(double decayConstant) {
79+
DecayingAverage(double decayConstant, ApiClock clock) {
7780
this.decayConstant = decayConstant;
7881
this.mean = 0.0;
7982
this.weightedCount = 0.0;
80-
this.lastUpdateTimeInSecond = new AtomicLong(0);
83+
this.clock = clock;
84+
this.decayCycleStartEpoch = TimeUnit.MILLISECONDS.toSeconds(clock.millisTime());
8185
}
8286

83-
synchronized void update(long value, long timestampMs) {
84-
long now = TimeUnit.MILLISECONDS.toSeconds(timestampMs);
85-
Preconditions.checkArgument(
86-
now >= lastUpdateTimeInSecond.get(), "can't update an event in the past");
87-
if (lastUpdateTimeInSecond.get() == 0) {
88-
lastUpdateTimeInSecond.set(now);
89-
mean = value;
90-
weightedCount = 1;
91-
} else {
92-
long prev = lastUpdateTimeInSecond.getAndSet(now);
93-
long elapsed = now - prev;
94-
double alpha = getAlpha(elapsed);
95-
// Exponential moving average = weightedSum / weightedCount, where
96-
// weightedSum(n) = value + alpha * weightedSum(n - 1)
97-
// weightedCount(n) = 1 + alpha * weightedCount(n - 1)
98-
// Using weighted count in case the sum overflows
99-
mean =
100-
mean * ((weightedCount * alpha) / (weightedCount * alpha + 1))
101-
+ value / (weightedCount * alpha + 1);
102-
weightedCount = weightedCount * alpha + 1;
103-
}
87+
synchronized void update(long value) {
88+
long now = TimeUnit.MILLISECONDS.toSeconds(clock.millisTime());
89+
double weight = getWeight(now);
90+
// Using weighted count in case the sum overflows.
91+
mean =
92+
mean * (weightedCount / (weightedCount + weight))
93+
+ weight * value / (weightedCount + weight);
94+
weightedCount += weight;
10495
}
10596

106-
double getMean(long timestampMs) {
107-
long timestampSecs = TimeUnit.MILLISECONDS.toSeconds(timestampMs);
108-
long elapsed = timestampSecs - lastUpdateTimeInSecond.get();
109-
return mean * getAlpha(Math.max(0, elapsed));
97+
double getMean() {
98+
return mean;
11099
}
111100

112-
private double getAlpha(long elapsedSecond) {
113-
return Math.exp(-decayConstant * elapsedSecond);
101+
private double getWeight(long now) {
102+
long elapsedSecond = now - decayCycleStartEpoch;
103+
double weight = Math.exp(decayConstant * elapsedSecond);
104+
// Decay mean, weightedCount and reset decay cycle start epoch every 15 minutes, so the
105+
// values won't be infinite
106+
if (elapsedSecond > DECAY_CYCLE_SECOND) {
107+
mean /= weight;
108+
weightedCount /= weight;
109+
decayCycleStartEpoch = now;
110+
// After resetting start time, weight = e^0 = 1
111+
return 1;
112+
} else {
113+
return weight;
114+
}
114115
}
115116
}
116117
}

google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStatsTest.java

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,50 +17,56 @@
1717

1818
import static com.google.common.truth.Truth.assertThat;
1919

20+
import com.google.api.core.ApiClock;
2021
import java.util.LinkedList;
2122
import java.util.List;
2223
import java.util.concurrent.ExecutionException;
2324
import java.util.concurrent.ExecutorService;
2425
import java.util.concurrent.Executors;
2526
import java.util.concurrent.Future;
2627
import java.util.concurrent.TimeUnit;
28+
import org.junit.Rule;
2729
import org.junit.Test;
2830
import org.junit.runner.RunWith;
2931
import org.junit.runners.JUnit4;
32+
import org.mockito.Mock;
33+
import org.mockito.Mockito;
34+
import org.mockito.junit.MockitoJUnit;
35+
import org.mockito.junit.MockitoRule;
3036

3137
@RunWith(JUnit4.class)
3238
public class DynamicFlowControlStatsTest {
3339

40+
@Rule public final MockitoRule rule = MockitoJUnit.rule();
41+
42+
@Mock private ApiClock clock;
43+
3444
@Test
3545
public void testUpdate() {
36-
DynamicFlowControlStats stats = new DynamicFlowControlStats();
37-
long now = System.currentTimeMillis();
3846

39-
stats.updateLatency(10, now);
40-
assertThat(stats.getMeanLatency(now)).isEqualTo(10);
41-
42-
stats.updateLatency(10, now);
43-
stats.updateLatency(10, now);
44-
assertThat(stats.getMeanLatency(now)).isEqualTo(10);
47+
Mockito.when(clock.millisTime()).thenReturn(0L);
48+
DynamicFlowControlStats stats = new DynamicFlowControlStats(0.015, clock);
49+
stats.updateLatency(10);
50+
assertThat(stats.getMeanLatency()).isEqualTo(10);
51+
stats.updateLatency(10);
52+
stats.updateLatency(10);
53+
assertThat(stats.getMeanLatency()).isEqualTo(10);
4554

4655
// In five minutes the previous latency should be decayed to under 1. And the new average should
4756
// be very close to 20
48-
long fiveMinutesLater = now + TimeUnit.MINUTES.toMillis(5);
49-
assertThat(stats.getMeanLatency(fiveMinutesLater)).isLessThan(1);
50-
stats.updateLatency(20, fiveMinutesLater);
51-
assertThat(stats.getMeanLatency(fiveMinutesLater)).isGreaterThan(19);
52-
assertThat(stats.getMeanLatency(fiveMinutesLater)).isLessThan(20);
53-
54-
long aDayLater = now + TimeUnit.HOURS.toMillis(24);
55-
assertThat(stats.getMeanLatency(aDayLater)).isZero();
57+
Mockito.when(clock.millisTime()).thenReturn(TimeUnit.MINUTES.toMillis(5));
58+
stats.updateLatency(20);
59+
assertThat(stats.getMeanLatency()).isGreaterThan(19);
60+
assertThat(stats.getMeanLatency()).isLessThan(20);
5661

57-
long timestamp = aDayLater;
62+
// After a day
63+
long aDay = TimeUnit.DAYS.toMillis(1);
5864
for (int i = 0; i < 10; i++) {
59-
timestamp += TimeUnit.SECONDS.toMillis(i);
60-
stats.updateLatency(i, timestamp);
65+
Mockito.when(clock.millisTime()).thenReturn(aDay + TimeUnit.SECONDS.toMillis(i));
66+
stats.updateLatency(i);
6167
}
62-
assertThat(stats.getMeanLatency(timestamp)).isGreaterThan(4.5);
63-
assertThat(stats.getMeanLatency(timestamp)).isLessThan(6);
68+
assertThat(stats.getMeanLatency()).isGreaterThan(4.5);
69+
assertThat(stats.getMeanLatency()).isLessThan(6);
6470
}
6571

6672
@Test(timeout = 1000)

0 commit comments

Comments
 (0)