4646import org .elasticsearch .common .util .concurrent .EsRejectedExecutionException ;
4747import org .elasticsearch .common .util .concurrent .StoppableExecutorServiceWrapper ;
4848import org .elasticsearch .common .util .concurrent .ThreadContext ;
49+ import org .elasticsearch .core .Releasable ;
50+ import org .elasticsearch .core .Releasables ;
4951import org .elasticsearch .core .SuppressForbidden ;
5052import org .elasticsearch .core .TimeValue ;
5153import org .elasticsearch .core .Tuple ;
@@ -259,9 +261,42 @@ public void clusterStatePublished(ClusterState newClusterState) {
259261 assertThat (registeredActions .toString (), registeredActions , contains (MasterService .STATE_UPDATE_ACTION_NAME ));
260262 }
261263
262- public void testThreadContext () throws InterruptedException {
264+ public void testThreadContext () {
263265 try (var master = createMasterService (true )) {
264- final CountDownLatch latch = new CountDownLatch (1 );
266+
267+ master .setClusterStatePublisher ((clusterStatePublicationEvent , publishListener , ackListener ) -> {
268+ ClusterServiceUtils .setAllElapsedMillis (clusterStatePublicationEvent );
269+ try (var ignored = threadPool .getThreadContext ().newEmptyContext ()) {
270+ if (randomBoolean ()) {
271+ randomExecutor (threadPool ).execute (() -> publishListener .onResponse (null ));
272+ randomExecutor (threadPool ).execute (() -> ackListener .onCommit (TimeValue .timeValueMillis (randomInt (10000 ))));
273+ randomExecutor (threadPool ).execute (
274+ () -> ackListener .onNodeAck (
275+ clusterStatePublicationEvent .getNewState ().nodes ().getMasterNode (),
276+ randomBoolean () ? null : new RuntimeException ("simulated ack failure" )
277+ )
278+ );
279+ } else {
280+ randomExecutor (threadPool ).execute (
281+ () -> publishListener .onFailure (new FailedToCommitClusterStateException ("simulated publish failure" ))
282+ );
283+ }
284+ }
285+ });
286+
287+ final Releasable onPublishComplete ;
288+ final Releasable onAckingComplete ;
289+ final Runnable awaitComplete ;
290+ {
291+ final var publishLatch = new CountDownLatch (1 );
292+ final var ackingLatch = new CountDownLatch (1 );
293+ onPublishComplete = Releasables .assertOnce (publishLatch ::countDown );
294+ onAckingComplete = Releasables .assertOnce (ackingLatch ::countDown );
295+ awaitComplete = () -> {
296+ safeAwait (publishLatch );
297+ safeAwait (ackingLatch );
298+ };
299+ }
265300
266301 try (ThreadContext .StoredContext ignored = threadPool .getThreadContext ().stashContext ()) {
267302
@@ -272,15 +307,12 @@ public void testThreadContext() throws InterruptedException {
272307 expectedHeaders .put (copiedHeader , randomIdentifier ());
273308 }
274309 }
275-
276- final Map <String , List <String >> expectedResponseHeaders = Collections .singletonMap (
277- "testResponse" ,
278- Collections .singletonList ("testResponse" )
279- );
280310 threadPool .getThreadContext ().putHeader (expectedHeaders );
281311
282- final TimeValue ackTimeout = randomBoolean () ? TimeValue .ZERO : TimeValue .timeValueMillis (randomInt (10000 ));
283- final TimeValue masterTimeout = randomBoolean () ? TimeValue .ZERO : TimeValue .timeValueMillis (randomInt (10000 ));
312+ final Map <String , List <String >> expectedResponseHeaders = Map .of ("testResponse" , List .of (randomIdentifier ()));
313+
314+ final TimeValue ackTimeout = randomBoolean () ? TimeValue .MINUS_ONE : TimeValue .timeValueMillis (randomInt (10000 ));
315+ final TimeValue masterTimeout = randomBoolean () ? TimeValue .MINUS_ONE : TimeValue .timeValueMillis (randomInt (10000 ));
284316
285317 master .submitUnbatchedStateUpdateTask (
286318 "test" ,
@@ -289,8 +321,9 @@ public void testThreadContext() throws InterruptedException {
289321 public ClusterState execute (ClusterState currentState ) {
290322 assertTrue (threadPool .getThreadContext ().isSystemContext ());
291323 assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getHeaders ());
292- threadPool .getThreadContext ().addResponseHeader ("testResponse" , "testResponse" );
293- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
324+ expectedResponseHeaders .forEach (
325+ (name , values ) -> values .forEach (v -> threadPool .getThreadContext ().addResponseHeader (name , v ))
326+ );
294327
295328 if (randomBoolean ()) {
296329 return ClusterState .builder (currentState ).build ();
@@ -303,44 +336,44 @@ public ClusterState execute(ClusterState currentState) {
303336
304337 @ Override
305338 public void onFailure (Exception e ) {
306- assertFalse (threadPool .getThreadContext ().isSystemContext ());
307- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
308- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
309- latch .countDown ();
339+ assertExpectedThreadContext (
340+ e instanceof ProcessClusterEventTimeoutException ? Map .of () : expectedResponseHeaders
341+ );
342+ onPublishComplete .close ();
343+ onAckingComplete .close (); // no acking takes place if publication failed
310344 }
311345
312346 @ Override
313347 public void clusterStateProcessed (ClusterState oldState , ClusterState newState ) {
314- assertFalse (threadPool .getThreadContext ().isSystemContext ());
315- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
316- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
317- latch .countDown ();
348+ assertExpectedThreadContext (expectedResponseHeaders );
349+ onPublishComplete .close ();
318350 }
319351
320352 @ Override
321353 public void onAllNodesAcked () {
322- assertFalse (threadPool .getThreadContext ().isSystemContext ());
323- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
324- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
325- latch .countDown ();
354+ onAckCompletion ();
326355 }
327356
328357 @ Override
329358 public void onAckFailure (Exception e ) {
330- assertFalse (threadPool .getThreadContext ().isSystemContext ());
331- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
332- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
333- latch .countDown ();
359+ onAckCompletion ();
334360 }
335361
336362 @ Override
337363 public void onAckTimeout () {
364+ onAckCompletion ();
365+ }
366+
367+ private void onAckCompletion () {
368+ assertExpectedThreadContext (expectedResponseHeaders );
369+ onAckingComplete .close ();
370+ }
371+
372+ private void assertExpectedThreadContext (Map <String , List <String >> expectedResponseHeaders ) {
338373 assertFalse (threadPool .getThreadContext ().isSystemContext ());
339374 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
340375 assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
341- latch .countDown ();
342376 }
343-
344377 }
345378 );
346379
@@ -349,7 +382,7 @@ public void onAckTimeout() {
349382 assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getResponseHeaders ());
350383 }
351384
352- assertTrue ( latch . await ( 10 , TimeUnit . SECONDS ) );
385+ awaitComplete . run ( );
353386 }
354387 }
355388
0 commit comments