4545import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
4646import org.elasticsearch.common.util.concurrent.StoppableExecutorServiceWrapper;
4747import org.elasticsearch.common.util.concurrent.ThreadContext;
48+ import org.elasticsearch.core.Releasable;
49+ import org.elasticsearch.core.Releasables;
4850import org.elasticsearch.core.SuppressForbidden;
4951import org.elasticsearch.core.TimeValue;
5052import org.elasticsearch.core.Tuple;
@@ -258,9 +260,42 @@ public void clusterStatePublished(ClusterState newClusterState) {
258260 assertThat(registeredActions.toString(), registeredActions, contains(MasterService.STATE_UPDATE_ACTION_NAME));
259261 }
260262
261- public void testThreadContext() throws InterruptedException {
263+ public void testThreadContext() {
262264 try (var master = createMasterService(true)) {
263- final CountDownLatch latch = new CountDownLatch(1);
265+
266+ master.setClusterStatePublisher((clusterStatePublicationEvent, publishListener, ackListener) -> {
267+ ClusterServiceUtils.setAllElapsedMillis(clusterStatePublicationEvent);
268+ try (var ignored = threadPool.getThreadContext().newEmptyContext()) {
269+ if (randomBoolean()) {
270+ randomExecutor(threadPool).execute(() -> publishListener.onResponse(null));
271+ randomExecutor(threadPool).execute(() -> ackListener.onCommit(TimeValue.timeValueMillis(randomInt(10000))));
272+ randomExecutor(threadPool).execute(
273+ () -> ackListener.onNodeAck(
274+ clusterStatePublicationEvent.getNewState().nodes().getMasterNode(),
275+ randomBoolean() ? null : new RuntimeException("simulated ack failure")
276+ )
277+ );
278+ } else {
279+ randomExecutor(threadPool).execute(
280+ () -> publishListener.onFailure(new FailedToCommitClusterStateException("simulated publish failure"))
281+ );
282+ }
283+ }
284+ });
285+
286+ final Releasable onPublishComplete;
287+ final Releasable onAckingComplete;
288+ final Runnable awaitComplete;
289+ {
290+ final var publishLatch = new CountDownLatch(1);
291+ final var ackingLatch = new CountDownLatch(1);
292+ onPublishComplete = Releasables.assertOnce(publishLatch::countDown);
293+ onAckingComplete = Releasables.assertOnce(ackingLatch::countDown);
294+ awaitComplete = () -> {
295+ safeAwait(publishLatch);
296+ safeAwait(ackingLatch);
297+ };
298+ }
264299
265300 try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
266301
@@ -271,15 +306,12 @@ public void testThreadContext() throws InterruptedException {
271306 expectedHeaders.put(copiedHeader, randomIdentifier());
272307 }
273308 }
274-
275- final Map<String, List<String>> expectedResponseHeaders = Collections.singletonMap(
276- "testResponse",
277- Collections.singletonList("testResponse")
278- );
279309 threadPool.getThreadContext().putHeader(expectedHeaders);
280310
281- final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
282- final TimeValue masterTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
311+ final Map<String, List<String>> expectedResponseHeaders = Map.of("testResponse", List.of(randomIdentifier()));
312+
313+ final TimeValue ackTimeout = randomBoolean() ? TimeValue.MINUS_ONE : TimeValue.timeValueMillis(randomInt(10000));
314+ final TimeValue masterTimeout = randomBoolean() ? TimeValue.MINUS_ONE : TimeValue.timeValueMillis(randomInt(10000));
283315
284316 master.submitUnbatchedStateUpdateTask(
285317 "test",
@@ -288,8 +320,9 @@ public void testThreadContext() throws InterruptedException {
288320 public ClusterState execute(ClusterState currentState) {
289321 assertTrue(threadPool.getThreadContext().isSystemContext());
290322 assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());
291- threadPool.getThreadContext().addResponseHeader("testResponse", "testResponse");
292- assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
323+ expectedResponseHeaders.forEach(
324+ (name, values) -> values.forEach(v -> threadPool.getThreadContext().addResponseHeader(name, v))
325+ );
293326
294327 if (randomBoolean()) {
295328 return ClusterState.builder(currentState).build();
@@ -302,44 +335,44 @@ public ClusterState execute(ClusterState currentState) {
302335
303336 @Override
304337 public void onFailure(Exception e) {
305- assertFalse(threadPool.getThreadContext().isSystemContext());
306- assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
307- assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
308- latch.countDown();
338+ assertExpectedThreadContext(
339+ e instanceof ProcessClusterEventTimeoutException ? Map.of() : expectedResponseHeaders
340+ );
341+ onPublishComplete.close();
342+ onAckingComplete.close(); // no acking takes place if publication failed
309343 }
310344
311345 @Override
312346 public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
313- assertFalse(threadPool.getThreadContext().isSystemContext());
314- assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
315- assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
316- latch.countDown();
347+ assertExpectedThreadContext(expectedResponseHeaders);
348+ onPublishComplete.close();
317349 }
318350
319351 @Override
320352 public void onAllNodesAcked() {
321- assertFalse(threadPool.getThreadContext().isSystemContext());
322- assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
323- assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
324- latch.countDown();
353+ onAckCompletion();
325354 }
326355
327356 @Override
328357 public void onAckFailure(Exception e) {
329- assertFalse(threadPool.getThreadContext().isSystemContext());
330- assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
331- assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
332- latch.countDown();
358+ onAckCompletion();
333359 }
334360
335361 @Override
336362 public void onAckTimeout() {
363+ onAckCompletion();
364+ }
365+
366+ private void onAckCompletion() {
367+ assertExpectedThreadContext(expectedResponseHeaders);
368+ onAckingComplete.close();
369+ }
370+
371+ private void assertExpectedThreadContext(Map<String, List<String>> expectedResponseHeaders) {
337372 assertFalse(threadPool.getThreadContext().isSystemContext());
338373 assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
339374 assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
340- latch.countDown();
341375 }
342-
343376 }
344377 );
345378
@@ -348,7 +381,7 @@ public void onAckTimeout() {
348381 assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getResponseHeaders());
349382 }
350383
351- assertTrue(latch.await(10, TimeUnit.SECONDS) );
384+ awaitComplete.run( );
352385 }
353386 }
354387
0 commit comments