4545import org .elasticsearch .common .util .concurrent .EsRejectedExecutionException ;
4646import org .elasticsearch .common .util .concurrent .StoppableExecutorServiceWrapper ;
4747import org .elasticsearch .common .util .concurrent .ThreadContext ;
48+ import org .elasticsearch .core .Releasable ;
49+ import org .elasticsearch .core .Releasables ;
4850import org .elasticsearch .core .SuppressForbidden ;
4951import org .elasticsearch .core .TimeValue ;
5052import org .elasticsearch .core .Tuple ;
@@ -258,9 +260,42 @@ public void clusterStatePublished(ClusterState newClusterState) {
258260 assertThat (registeredActions .toString (), registeredActions , contains (MasterService .STATE_UPDATE_ACTION_NAME ));
259261 }
260262
261- public void testThreadContext () throws InterruptedException {
263+ public void testThreadContext () {
262264 try (var master = createMasterService (true )) {
263- final CountDownLatch latch = new CountDownLatch (1 );
265+
266+ master .setClusterStatePublisher ((clusterStatePublicationEvent , publishListener , ackListener ) -> {
267+ ClusterServiceUtils .setAllElapsedMillis (clusterStatePublicationEvent );
268+ try (var ignored = threadPool .getThreadContext ().newEmptyContext ()) {
269+ if (randomBoolean ()) {
270+ randomExecutor (threadPool ).execute (() -> publishListener .onResponse (null ));
271+ randomExecutor (threadPool ).execute (() -> ackListener .onCommit (TimeValue .timeValueMillis (randomInt (10000 ))));
272+ randomExecutor (threadPool ).execute (
273+ () -> ackListener .onNodeAck (
274+ clusterStatePublicationEvent .getNewState ().nodes ().getMasterNode (),
275+ randomBoolean () ? null : new RuntimeException ("simulated ack failure" )
276+ )
277+ );
278+ } else {
279+ randomExecutor (threadPool ).execute (
280+ () -> publishListener .onFailure (new FailedToCommitClusterStateException ("simulated publish failure" ))
281+ );
282+ }
283+ }
284+ });
285+
286+ final Releasable onPublishComplete ;
287+ final Releasable onAckingComplete ;
288+ final Runnable awaitComplete ;
289+ {
290+ final var publishLatch = new CountDownLatch (1 );
291+ final var ackingLatch = new CountDownLatch (1 );
292+ onPublishComplete = Releasables .assertOnce (publishLatch ::countDown );
293+ onAckingComplete = Releasables .assertOnce (ackingLatch ::countDown );
294+ awaitComplete = () -> {
295+ safeAwait (publishLatch );
296+ safeAwait (ackingLatch );
297+ };
298+ }
264299
265300 try (ThreadContext .StoredContext ignored = threadPool .getThreadContext ().stashContext ()) {
266301
@@ -271,15 +306,12 @@ public void testThreadContext() throws InterruptedException {
271306 expectedHeaders .put (copiedHeader , randomIdentifier ());
272307 }
273308 }
274-
275- final Map <String , List <String >> expectedResponseHeaders = Collections .singletonMap (
276- "testResponse" ,
277- Collections .singletonList ("testResponse" )
278- );
279309 threadPool .getThreadContext ().putHeader (expectedHeaders );
280310
281- final TimeValue ackTimeout = randomBoolean () ? TimeValue .ZERO : TimeValue .timeValueMillis (randomInt (10000 ));
282- final TimeValue masterTimeout = randomBoolean () ? TimeValue .ZERO : TimeValue .timeValueMillis (randomInt (10000 ));
311+ final Map <String , List <String >> expectedResponseHeaders = Map .of ("testResponse" , List .of (randomIdentifier ()));
312+
313+ final TimeValue ackTimeout = randomBoolean () ? TimeValue .MINUS_ONE : TimeValue .timeValueMillis (randomInt (10000 ));
314+ final TimeValue masterTimeout = randomBoolean () ? TimeValue .MINUS_ONE : TimeValue .timeValueMillis (randomInt (10000 ));
283315
284316 master .submitUnbatchedStateUpdateTask (
285317 "test" ,
@@ -288,8 +320,9 @@ public void testThreadContext() throws InterruptedException {
288320 public ClusterState execute (ClusterState currentState ) {
289321 assertTrue (threadPool .getThreadContext ().isSystemContext ());
290322 assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getHeaders ());
291- threadPool .getThreadContext ().addResponseHeader ("testResponse" , "testResponse" );
292- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
323+ expectedResponseHeaders .forEach (
324+ (name , values ) -> values .forEach (v -> threadPool .getThreadContext ().addResponseHeader (name , v ))
325+ );
293326
294327 if (randomBoolean ()) {
295328 return ClusterState .builder (currentState ).build ();
@@ -302,44 +335,44 @@ public ClusterState execute(ClusterState currentState) {
302335
303336 @ Override
304337 public void onFailure (Exception e ) {
305- assertFalse (threadPool .getThreadContext ().isSystemContext ());
306- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
307- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
308- latch .countDown ();
338+ assertExpectedThreadContext (
339+ e instanceof ProcessClusterEventTimeoutException ? Map .of () : expectedResponseHeaders
340+ );
341+ onPublishComplete .close ();
342+ onAckingComplete .close (); // no acking takes place if publication failed
309343 }
310344
311345 @ Override
312346 public void clusterStateProcessed (ClusterState oldState , ClusterState newState ) {
313- assertFalse (threadPool .getThreadContext ().isSystemContext ());
314- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
315- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
316- latch .countDown ();
347+ assertExpectedThreadContext (expectedResponseHeaders );
348+ onPublishComplete .close ();
317349 }
318350
319351 @ Override
320352 public void onAllNodesAcked () {
321- assertFalse (threadPool .getThreadContext ().isSystemContext ());
322- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
323- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
324- latch .countDown ();
353+ onAckCompletion ();
325354 }
326355
327356 @ Override
328357 public void onAckFailure (Exception e ) {
329- assertFalse (threadPool .getThreadContext ().isSystemContext ());
330- assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
331- assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
332- latch .countDown ();
358+ onAckCompletion ();
333359 }
334360
335361 @ Override
336362 public void onAckTimeout () {
363+ onAckCompletion ();
364+ }
365+
366+ private void onAckCompletion () {
367+ assertExpectedThreadContext (expectedResponseHeaders );
368+ onAckingComplete .close ();
369+ }
370+
371+ private void assertExpectedThreadContext (Map <String , List <String >> expectedResponseHeaders ) {
337372 assertFalse (threadPool .getThreadContext ().isSystemContext ());
338373 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
339374 assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
340- latch .countDown ();
341375 }
342-
343376 }
344377 );
345378
@@ -348,7 +381,7 @@ public void onAckTimeout() {
348381 assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getResponseHeaders ());
349382 }
350383
351- assertTrue ( latch . await ( 10 , TimeUnit . SECONDS ) );
384+ awaitComplete . run ( );
352385 }
353386 }
354387
0 commit comments