Skip to content

Commit 663a44b

Browse files
authored
RATIS-2245. Ratis should wait for all apply transaction futures before taking snapshot and group remove (#1218)
1 parent a9ebdb6 commit 663a44b

File tree

2 files changed

+119
-47
lines changed

2 files changed

+119
-47
lines changed

ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
import org.slf4j.LoggerFactory;
3838

3939
import java.io.IOException;
40-
import java.util.ArrayList;
41-
import java.util.List;
4240
import java.util.Objects;
4341
import java.util.Optional;
4442
import java.util.concurrent.CompletableFuture;
@@ -182,19 +180,20 @@ public String toString() {
182180

183181
@Override
184182
public void run() {
183+
CompletableFuture<Void> applyLogFutures = CompletableFuture.completedFuture(null);
185184
for(; state != State.STOP; ) {
186185
try {
187-
waitForCommit();
186+
waitForCommit(applyLogFutures);
188187

189188
if (state == State.RELOAD) {
190189
reload();
191190
}
192191

193-
final MemoizedSupplier<List<CompletableFuture<Message>>> futures = applyLog();
194-
checkAndTakeSnapshot(futures);
192+
applyLogFutures = applyLog(applyLogFutures);
193+
checkAndTakeSnapshot(applyLogFutures);
195194

196195
if (shouldStop()) {
197-
checkAndTakeSnapshot(futures);
196+
applyLogFutures.get();
198197
stop();
199198
}
200199
} catch (Throwable t) {
@@ -210,14 +209,14 @@ public void run() {
210209
}
211210
}
212211

213-
private void waitForCommit() throws InterruptedException {
212+
private void waitForCommit(CompletableFuture<?> applyLogFutures) throws InterruptedException, ExecutionException {
214213
// When a peer starts, the committed is initialized to 0.
215214
// It will be updated only after the leader contacts other peers.
216215
// Thus it is possible to have applied > committed initially.
217216
final long applied = getLastAppliedIndex();
218217
for(; applied >= raftLog.getLastCommittedIndex() && state == State.RUNNING && !shouldStop(); ) {
219218
if (server.getSnapshotRequestHandler().shouldTriggerTakingSnapshot()) {
220-
takeSnapshot();
219+
takeSnapshot(applyLogFutures);
221220
}
222221
if (awaitForSignal.await(100, TimeUnit.MILLISECONDS)) {
223222
return;
@@ -239,8 +238,7 @@ private void reload() throws IOException {
239238
state = State.RUNNING;
240239
}
241240

242-
private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws RaftLogIOException {
243-
final MemoizedSupplier<List<CompletableFuture<Message>>> futures = MemoizedSupplier.valueOf(ArrayList::new);
241+
private CompletableFuture<Void> applyLog(CompletableFuture<Void> applyLogFutures) throws RaftLogIOException {
244242
final long committed = raftLog.getLastCommittedIndex();
245243
for(long applied; (applied = getLastAppliedIndex()) < committed && state == State.RUNNING && !shouldStop(); ) {
246244
final long nextIndex = applied + 1;
@@ -263,7 +261,12 @@ private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws Raf
263261
final long incremented = appliedIndex.incrementAndGet(debugIndexChange);
264262
Preconditions.assertTrue(incremented == nextIndex);
265263
if (f != null) {
266-
futures.get().add(f);
264+
CompletableFuture<Message> exceptionHandledFuture = f.exceptionally(ex -> {
265+
LOG.error("Exception while {}: applying txn index={}, nextLog={}", this, nextIndex,
266+
LogProtoUtils.toLogEntryString(entry), ex);
267+
return null;
268+
});
269+
applyLogFutures = applyLogFutures.thenCombine(exceptionHandledFuture, (v, message) -> null);
267270
f.thenAccept(m -> notifyAppliedIndex(incremented));
268271
} else {
269272
notifyAppliedIndex(incremented);
@@ -272,23 +275,20 @@ private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws Raf
272275
next.release();
273276
}
274277
}
275-
return futures;
278+
return applyLogFutures;
276279
}
277280

278-
private void checkAndTakeSnapshot(MemoizedSupplier<List<CompletableFuture<Message>>> futures)
281+
private void checkAndTakeSnapshot(CompletableFuture<?> futures)
279282
throws ExecutionException, InterruptedException {
280283
// check if need to trigger a snapshot
281284
if (shouldTakeSnapshot()) {
282-
if (futures.isInitialized()) {
283-
JavaUtils.allOf(futures.get()).get();
284-
}
285-
286-
takeSnapshot();
285+
takeSnapshot(futures);
287286
}
288287
}
289288

290-
private void takeSnapshot() {
289+
private void takeSnapshot(CompletableFuture<?> applyLogFutures) throws ExecutionException, InterruptedException {
291290
final long i;
291+
applyLogFutures.get();
292292
try {
293293
try(UncheckedAutoCloseable ignored = Timekeeper.start(stateMachineMetrics.get().getTakeSnapshotTimer())) {
294294
i = stateMachine.takeSnapshot();

ratis-server/src/test/java/org/apache/ratis/server/impl/StateMachineShutdownTests.java

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,47 +28,106 @@
2828
import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
2929
import org.apache.ratis.statemachine.StateMachine;
3030
import org.apache.ratis.statemachine.TransactionContext;
31-
import org.junit.Assert;
32-
import org.junit.Test;
33-
34-
import java.util.concurrent.CompletableFuture;
31+
import org.junit.*;
32+
import org.mockito.MockedStatic;
33+
import org.mockito.Mockito;
34+
import org.slf4j.Logger;
35+
import org.slf4j.LoggerFactory;
3536

37+
import java.util.*;
38+
import java.util.concurrent.*;
39+
import java.util.concurrent.atomic.AtomicLong;
3640

3741
public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster>
3842
extends BaseTest
3943
implements MiniRaftCluster.Factory.Get<CLUSTER> {
40-
44+
public static Logger LOG = LoggerFactory.getLogger(StateMachineUpdater.class);
45+
private static MockedStatic<CompletableFuture> mocked;
4146
protected static class StateMachineWithConditionalWait extends
4247
SimpleStateMachine4Testing {
48+
boolean unblockAllTxns = false;
49+
final Set<Long> blockTxns = ConcurrentHashMap.newKeySet();
50+
private final ExecutorService executor = Executors.newFixedThreadPool(10);
51+
public static Map<Long, Set<CompletableFuture<Message>>> futures = new ConcurrentHashMap<>();
52+
public static Map<RaftPeerId, AtomicLong> numTxns = new ConcurrentHashMap<>();
53+
private final Map<Long, Long> appliedTxns = new ConcurrentHashMap<>();
54+
55+
private synchronized void updateTxns() {
56+
long appliedIndex = this.getLastAppliedTermIndex().getIndex() + 1;
57+
Long appliedTerm = null;
58+
while (appliedTxns.containsKey(appliedIndex)) {
59+
appliedTerm = appliedTxns.remove(appliedIndex);
60+
appliedIndex += 1;
61+
}
62+
if (appliedTerm != null) {
63+
updateLastAppliedTermIndex(appliedTerm, appliedIndex - 1);
64+
}
65+
}
4366

44-
private final Long objectToWait = 0L;
45-
volatile boolean blockOnApply = true;
67+
@Override
68+
public void notifyTermIndexUpdated(long term, long index) {
69+
appliedTxns.put(index, term);
70+
updateTxns();
71+
}
4672

4773
@Override
4874
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
49-
if (blockOnApply) {
50-
synchronized (objectToWait) {
51-
try {
52-
objectToWait.wait();
53-
} catch (InterruptedException e) {
54-
Thread.currentThread().interrupt();
55-
throw new RuntimeException();
75+
final RaftProtos.LogEntryProto entry = trx.getLogEntryUnsafe();
76+
77+
CompletableFuture<Message> future = new CompletableFuture<>();
78+
futures.computeIfAbsent(Thread.currentThread().getId(), k -> new HashSet<>()).add(future);
79+
executor.submit(() -> {
80+
synchronized (blockTxns) {
81+
if (!unblockAllTxns) {
82+
blockTxns.add(entry.getIndex());
83+
}
84+
while (!unblockAllTxns && blockTxns.contains(entry.getIndex())) {
85+
try {
86+
blockTxns.wait(10000);
87+
} catch (InterruptedException e) {
88+
throw new RuntimeException(e);
89+
}
5690
}
5791
}
92+
numTxns.computeIfAbsent(getId(), (k) -> new AtomicLong()).incrementAndGet();
93+
appliedTxns.put(entry.getIndex(), entry.getTerm());
94+
updateTxns();
95+
future.complete(new RaftTestUtil.SimpleMessage("done"));
96+
});
97+
return future;
98+
}
99+
100+
public void unBlockApplyTxn(long txnId) {
101+
synchronized (blockTxns) {
102+
blockTxns.remove(txnId);
103+
blockTxns.notifyAll();
58104
}
59-
final RaftProtos.LogEntryProto entry = trx.getLogEntryUnsafe();
60-
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
61-
return CompletableFuture.completedFuture(new RaftTestUtil.SimpleMessage("done"));
62105
}
63106

64-
public void unBlockApplyTxn() {
65-
blockOnApply = false;
66-
synchronized (objectToWait) {
67-
objectToWait.notifyAll();
107+
public void unblockAllTxns() {
108+
unblockAllTxns = true;
109+
synchronized (blockTxns) {
110+
for (Long txnId : blockTxns) {
111+
blockTxns.remove(txnId);
112+
}
113+
blockTxns.notifyAll();
68114
}
69115
}
70116
}
71117

118+
@Before
119+
public void setup() {
120+
mocked = Mockito.mockStatic(CompletableFuture.class, Mockito.CALLS_REAL_METHODS);
121+
}
122+
123+
@After
124+
public void tearDownClass() {
125+
if (mocked != null) {
126+
mocked.close();
127+
}
128+
129+
}
130+
72131
@Test
73132
public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
74133
final RaftProperties prop = getProperties();
@@ -82,10 +141,9 @@ public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
82141

83142
//Unblock leader and one follower
84143
((StateMachineWithConditionalWait)leader.getStateMachine())
85-
.unBlockApplyTxn();
144+
.unblockAllTxns();
86145
((StateMachineWithConditionalWait)cluster.
87-
getFollowers().get(0).getStateMachine()).unBlockApplyTxn();
88-
146+
getFollowers().get(0).getStateMachine()).unblockAllTxns();
89147
cluster.getLeaderAndSendFirstMessage(true);
90148

91149
try (final RaftClient client = cluster.createClient(leaderId)) {
@@ -107,16 +165,30 @@ public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
107165
final Thread t = new Thread(secondFollower::close);
108166
t.start();
109167

110-
// The second follower should still be blocked in apply transaction
111-
Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < logIndex);
168+
112169

113170
// Now unblock the second follower
114-
((StateMachineWithConditionalWait) secondFollower.getStateMachine())
115-
.unBlockApplyTxn();
171+
long minIndex = ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).blockTxns.stream()
172+
.min(Comparator.naturalOrder()).get();
173+
Assert.assertEquals(2, StateMachineWithConditionalWait.numTxns.values().stream()
174+
.filter(val -> val.get() == 3).count());
175+
// The second follower should still be blocked in apply transaction
176+
Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex);
177+
for (long index : ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).blockTxns) {
178+
if (minIndex != index) {
179+
((StateMachineWithConditionalWait) secondFollower.getStateMachine()).unBlockApplyTxn(index);
180+
}
181+
}
182+
Assert.assertEquals(2, StateMachineWithConditionalWait.numTxns.values().stream()
183+
.filter(val -> val.get() == 3).count());
184+
Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex);
185+
((StateMachineWithConditionalWait) secondFollower.getStateMachine()).unBlockApplyTxn(minIndex);
116186

117187
// Now wait for the thread
118188
t.join(5000);
119189
Assert.assertEquals(logIndex, secondFollower.getInfo().getLastAppliedIndex());
190+
Assert.assertEquals(3, StateMachineWithConditionalWait.numTxns.values().stream()
191+
.filter(val -> val.get() == 3).count());
120192

121193
cluster.shutdown();
122194
}

0 commit comments

Comments
 (0)