2828import org .apache .ratis .statemachine .impl .SimpleStateMachine4Testing ;
2929import org .apache .ratis .statemachine .StateMachine ;
3030import org .apache .ratis .statemachine .TransactionContext ;
31- import org .junit .Assert ;
32- import org .junit .Test ;
33-
34- import java .util .concurrent .CompletableFuture ;
31+ import org .junit .*;
32+ import org .mockito .MockedStatic ;
33+ import org .mockito .Mockito ;
34+ import org .slf4j .Logger ;
35+ import org .slf4j .LoggerFactory ;
3536
37+ import java .util .*;
38+ import java .util .concurrent .*;
39+ import java .util .concurrent .atomic .AtomicLong ;
3640
3741public abstract class StateMachineShutdownTests <CLUSTER extends MiniRaftCluster >
3842 extends BaseTest
3943 implements MiniRaftCluster .Factory .Get <CLUSTER > {
40-
44+ public static Logger LOG = LoggerFactory .getLogger (StateMachineUpdater .class );
45+ private static MockedStatic <CompletableFuture > mocked ;
4146 protected static class StateMachineWithConditionalWait extends
4247 SimpleStateMachine4Testing {
48+ boolean unblockAllTxns = false ;
49+ final Set <Long > blockTxns = ConcurrentHashMap .newKeySet ();
50+ private final ExecutorService executor = Executors .newFixedThreadPool (10 );
51+ public static Map <Long , Set <CompletableFuture <Message >>> futures = new ConcurrentHashMap <>();
52+ public static Map <RaftPeerId , AtomicLong > numTxns = new ConcurrentHashMap <>();
53+ private final Map <Long , Long > appliedTxns = new ConcurrentHashMap <>();
54+
55+ private synchronized void updateTxns () {
56+ long appliedIndex = this .getLastAppliedTermIndex ().getIndex () + 1 ;
57+ Long appliedTerm = null ;
58+ while (appliedTxns .containsKey (appliedIndex )) {
59+ appliedTerm = appliedTxns .remove (appliedIndex );
60+ appliedIndex += 1 ;
61+ }
62+ if (appliedTerm != null ) {
63+ updateLastAppliedTermIndex (appliedTerm , appliedIndex - 1 );
64+ }
65+ }
4366
44- private final Long objectToWait = 0L ;
45- volatile boolean blockOnApply = true ;
67+ @ Override
68+ public void notifyTermIndexUpdated (long term , long index ) {
69+ appliedTxns .put (index , term );
70+ updateTxns ();
71+ }
4672
4773 @ Override
4874 public CompletableFuture <Message > applyTransaction (TransactionContext trx ) {
49- if (blockOnApply ) {
50- synchronized (objectToWait ) {
51- try {
52- objectToWait .wait ();
53- } catch (InterruptedException e ) {
54- Thread .currentThread ().interrupt ();
55- throw new RuntimeException ();
75+ final RaftProtos .LogEntryProto entry = trx .getLogEntryUnsafe ();
76+
77+ CompletableFuture <Message > future = new CompletableFuture <>();
78+ futures .computeIfAbsent (Thread .currentThread ().getId (), k -> new HashSet <>()).add (future );
79+ executor .submit (() -> {
80+ synchronized (blockTxns ) {
81+ if (!unblockAllTxns ) {
82+ blockTxns .add (entry .getIndex ());
83+ }
84+ while (!unblockAllTxns && blockTxns .contains (entry .getIndex ())) {
85+ try {
86+ blockTxns .wait (10000 );
87+ } catch (InterruptedException e ) {
88+ throw new RuntimeException (e );
89+ }
5690 }
5791 }
92+ numTxns .computeIfAbsent (getId (), (k ) -> new AtomicLong ()).incrementAndGet ();
93+ appliedTxns .put (entry .getIndex (), entry .getTerm ());
94+ updateTxns ();
95+ future .complete (new RaftTestUtil .SimpleMessage ("done" ));
96+ });
97+ return future ;
98+ }
99+
100+ public void unBlockApplyTxn (long txnId ) {
101+ synchronized (blockTxns ) {
102+ blockTxns .remove (txnId );
103+ blockTxns .notifyAll ();
58104 }
59- final RaftProtos .LogEntryProto entry = trx .getLogEntryUnsafe ();
60- updateLastAppliedTermIndex (entry .getTerm (), entry .getIndex ());
61- return CompletableFuture .completedFuture (new RaftTestUtil .SimpleMessage ("done" ));
62105 }
63106
64- public void unBlockApplyTxn () {
65- blockOnApply = false ;
66- synchronized (objectToWait ) {
67- objectToWait .notifyAll ();
107+ public void unblockAllTxns () {
108+ unblockAllTxns = true ;
109+ synchronized (blockTxns ) {
110+ for (Long txnId : blockTxns ) {
111+ blockTxns .remove (txnId );
112+ }
113+ blockTxns .notifyAll ();
68114 }
69115 }
70116 }
71117
118+ @ Before
119+ public void setup () {
120+ mocked = Mockito .mockStatic (CompletableFuture .class , Mockito .CALLS_REAL_METHODS );
121+ }
122+
123+ @ After
124+ public void tearDownClass () {
125+ if (mocked != null ) {
126+ mocked .close ();
127+ }
128+
129+ }
130+
72131 @ Test
73132 public void testStateMachineShutdownWaitsForApplyTxn () throws Exception {
74133 final RaftProperties prop = getProperties ();
@@ -82,10 +141,9 @@ public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
82141
83142 //Unblock leader and one follower
84143 ((StateMachineWithConditionalWait )leader .getStateMachine ())
85- . unBlockApplyTxn ();
144+ . unblockAllTxns ();
86145 ((StateMachineWithConditionalWait )cluster .
87- getFollowers ().get (0 ).getStateMachine ()).unBlockApplyTxn ();
88-
146+ getFollowers ().get (0 ).getStateMachine ()).unblockAllTxns ();
89147 cluster .getLeaderAndSendFirstMessage (true );
90148
91149 try (final RaftClient client = cluster .createClient (leaderId )) {
@@ -107,16 +165,30 @@ public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
107165 final Thread t = new Thread (secondFollower ::close );
108166 t .start ();
109167
110- // The second follower should still be blocked in apply transaction
111- Assert .assertTrue (secondFollower .getInfo ().getLastAppliedIndex () < logIndex );
168+
112169
113170 // Now unblock the second follower
114- ((StateMachineWithConditionalWait ) secondFollower .getStateMachine ())
115- .unBlockApplyTxn ();
171+ long minIndex = ((StateMachineWithConditionalWait ) secondFollower .getStateMachine ()).blockTxns .stream ()
172+ .min (Comparator .naturalOrder ()).get ();
173+ Assert .assertEquals (2 , StateMachineWithConditionalWait .numTxns .values ().stream ()
174+ .filter (val -> val .get () == 3 ).count ());
175+ // The second follower should still be blocked in apply transaction
176+ Assert .assertTrue (secondFollower .getInfo ().getLastAppliedIndex () < minIndex );
177+ for (long index : ((StateMachineWithConditionalWait ) secondFollower .getStateMachine ()).blockTxns ) {
178+ if (minIndex != index ) {
179+ ((StateMachineWithConditionalWait ) secondFollower .getStateMachine ()).unBlockApplyTxn (index );
180+ }
181+ }
182+ Assert .assertEquals (2 , StateMachineWithConditionalWait .numTxns .values ().stream ()
183+ .filter (val -> val .get () == 3 ).count ());
184+ Assert .assertTrue (secondFollower .getInfo ().getLastAppliedIndex () < minIndex );
185+ ((StateMachineWithConditionalWait ) secondFollower .getStateMachine ()).unBlockApplyTxn (minIndex );
116186
117187 // Now wait for the thread
118188 t .join (5000 );
119189 Assert .assertEquals (logIndex , secondFollower .getInfo ().getLastAppliedIndex ());
190+ Assert .assertEquals (3 , StateMachineWithConditionalWait .numTxns .values ().stream ()
191+ .filter (val -> val .get () == 3 ).count ());
120192
121193 cluster .shutdown ();
122194 }
0 commit comments