Skip to content

Commit 7d11e41

Browse files
arteamDaveCTurner
andauthored
[7.17] Preserve context in ResultDeduplicator (#84038) (#96868)
Today the `ResultDeduplicator` may complete a collection of listeners in contexts different from the ones in which they were submitted. This commit makes sure that the context is preserved in the listener. Co-authored-by: David Turner <[email protected]>
1 parent 9bb69a2 commit 7d11e41

File tree

7 files changed

+54
-20
lines changed

7 files changed

+54
-20
lines changed

docs/changelog/84038.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 84038
2+
summary: Preserve context in `ResultDeduplicator`
3+
area: Infra/Core
4+
type: bug
5+
issues:
6+
- 84036

server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
package org.elasticsearch.action;
1010

11+
import org.elasticsearch.action.support.ContextPreservingActionListener;
1112
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
1214

1315
import java.util.ArrayList;
1416
import java.util.List;
@@ -22,8 +24,13 @@
2224
*/
2325
public final class ResultDeduplicator<T, R> {
2426

27+
private final ThreadContext threadContext;
2528
private final ConcurrentMap<T, CompositeListener> requests = ConcurrentCollections.newConcurrentMap();
2629

30+
public ResultDeduplicator(ThreadContext threadContext) {
31+
this.threadContext = threadContext;
32+
}
33+
2734
/**
2835
* Ensures a given request not executed multiple times when another equal request is already in-flight.
2936
* If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener}
@@ -35,7 +42,8 @@ public final class ResultDeduplicator<T, R> {
3542
* @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator
3643
*/
3744
public void executeOnce(T request, ActionListener<R> listener, BiConsumer<T, ActionListener<R>> callback) {
38-
ActionListener<R> completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener);
45+
ActionListener<R> completionListener = requests.computeIfAbsent(request, CompositeListener::new)
46+
.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
3947
if (completionListener != null) {
4048
callback.accept(request, completionListener);
4149
}

server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ private static Priority parseReroutePriority(String priorityString) {
118118

119119
// a list of shards that failed during replication
120120
// we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
121-
private final ResultDeduplicator<FailedShardEntry, Void> remoteFailedShardsDeduplicator = new ResultDeduplicator<>();
121+
private final ResultDeduplicator<FailedShardEntry, Void> remoteFailedShardsDeduplicator;
122122

123123
@Inject
124124
public ShardStateAction(
@@ -131,6 +131,7 @@ public ShardStateAction(
131131
this.transportService = transportService;
132132
this.clusterService = clusterService;
133133
this.threadPool = threadPool;
134+
remoteFailedShardsDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
134135

135136
followUpRerouteTaskPriority = FOLLOW_UP_REROUTE_PRIORITY_SETTING.get(clusterService.getSettings());
136137
clusterService.getClusterSettings()

server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
8585
private final Map<Snapshot, Map<ShardId, IndexShardSnapshotStatus>> shardSnapshots = new HashMap<>();
8686

8787
// A map of snapshots to the shardIds that we already reported to the master as failed
88-
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator =
89-
new ResultDeduplicator<>();
88+
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator;
9089

9190
public SnapshotShardsService(
9291
Settings settings,
@@ -100,6 +99,7 @@ public SnapshotShardsService(
10099
this.transportService = transportService;
101100
this.clusterService = clusterService;
102101
this.threadPool = transportService.getThreadPool();
102+
this.remoteFailedRequestDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
103103
if (DiscoveryNode.canContainData(settings)) {
104104
// this is only useful on the nodes that can hold data
105105
clusterService.addListener(this);

server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ public class TaskCancellationService {
4444
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
4545
private final TransportService transportService;
4646
private final TaskManager taskManager;
47-
private final ResultDeduplicator<CancelRequest, Void> deduplicator = new ResultDeduplicator<>();
47+
private final ResultDeduplicator<CancelRequest, Void> deduplicator;
4848

4949
public TaskCancellationService(TransportService transportService) {
5050
this.transportService = transportService;
5151
this.taskManager = transportService.getTaskManager();
52+
this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
5253
transportService.registerRequestHandler(
5354
BAN_PARENT_ACTION_NAME,
5455
ThreadPool.Names.SAME,

server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import static org.hamcrest.Matchers.everyItem;
4747
import static org.hamcrest.Matchers.in;
4848
import static org.mockito.Mockito.mock;
49+
import static org.mockito.Mockito.when;
4950

5051
public class TaskManagerTests extends ESTestCase {
5152
private ThreadPool threadPool;
@@ -76,7 +77,9 @@ public void testResultsServiceRetryTotalTime() {
7677
public void testTrackingChannelTask() throws Exception {
7778
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
7879
Set<Task> cancelledTasks = ConcurrentCollections.newConcurrentSet();
79-
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
80+
final TransportService transportServiceMock = mock(TransportService.class);
81+
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
82+
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
8083
@Override
8184
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
8285
assertThat(reason, equalTo("channel was closed"));
@@ -124,7 +127,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
124127
public void testTrackingTaskAndCloseChannelConcurrently() throws Exception {
125128
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
126129
Set<CancellableTask> cancelledTasks = ConcurrentCollections.newConcurrentSet();
127-
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
130+
final TransportService transportServiceMock = mock(TransportService.class);
131+
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
132+
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
128133
@Override
129134
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
130135
assertTrue("task [" + task + "] was cancelled already", cancelledTasks.add(task));
@@ -180,7 +185,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
180185

181186
public void testRemoveBansOnChannelDisconnects() throws Exception {
182187
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
183-
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
188+
final TransportService transportServiceMock = mock(TransportService.class);
189+
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
190+
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
184191
@Override
185192
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {}
186193
});

server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import org.apache.lucene.util.SetOnce;
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.ResultDeduplicator;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.common.util.concurrent.ThreadContext;
1315
import org.elasticsearch.tasks.TaskId;
1416
import org.elasticsearch.test.ESTestCase;
1517

@@ -29,27 +31,36 @@ public void testRequestDeduplication() throws Exception {
2931
@Override
3032
public void setParentTask(final TaskId taskId) {}
3133
};
32-
final ResultDeduplicator<TransportRequest, Void> deduplicator = new ResultDeduplicator<>();
34+
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
35+
final ResultDeduplicator<TransportRequest, Void> deduplicator = new ResultDeduplicator<>(threadContext);
3336
final SetOnce<ActionListener<Void>> listenerHolder = new SetOnce<>();
37+
final String headerName = "thread-context-header";
38+
final AtomicInteger headerGenerator = new AtomicInteger();
3439
int iterationsPerThread = scaledRandomIntBetween(100, 1000);
3540
Thread[] threads = new Thread[between(1, 4)];
3641
Phaser barrier = new Phaser(threads.length + 1);
3742
for (int i = 0; i < threads.length; i++) {
3843
threads[i] = new Thread(() -> {
3944
barrier.arriveAndAwaitAdvance();
4045
for (int n = 0; n < iterationsPerThread; n++) {
41-
deduplicator.executeOnce(request, new ActionListener<Void>() {
42-
@Override
43-
public void onResponse(Void aVoid) {
44-
successCount.incrementAndGet();
45-
}
46+
final String headerValue = Integer.toString(headerGenerator.incrementAndGet());
47+
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
48+
threadContext.putHeader(headerName, headerValue);
49+
deduplicator.executeOnce(request, new ActionListener<Void>() {
50+
@Override
51+
public void onResponse(Void aVoid) {
52+
assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
53+
successCount.incrementAndGet();
54+
}
4655

47-
@Override
48-
public void onFailure(Exception e) {
49-
assertThat(e, sameInstance(failure));
50-
failureCount.incrementAndGet();
51-
}
52-
}, (req, reqListener) -> listenerHolder.set(reqListener));
56+
@Override
57+
public void onFailure(Exception e) {
58+
assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
59+
assertThat(e, sameInstance(failure));
60+
failureCount.incrementAndGet();
61+
}
62+
}, (req, reqListener) -> listenerHolder.set(reqListener));
63+
}
5364
}
5465
});
5566
threads[i].start();

0 commit comments

Comments
 (0)