Skip to content

Commit 01edab5

Browse files
Fix NPE caused by race condition in async search when minimise round trips is true (#117504)
* Fix NPE caused by race condition in async search when minimise round trips is true

  Previously, `notifyListShards()` both initialised and updated the required prerequisites (`searchResponse` being among them) when a search op began. This function takes arguments that contain, among other things, shard-specific details. Because this information is not immediately available when the search begins, the function is not called immediately. In some specific cases, a race condition can cause the prerequisites (such as `searchResponse`) to be accessed before they're initialised, causing an NPE. This fix addresses the race condition by splitting the initialisation and the subsequent update into two separate methods. This way, the prerequisites are always initialised and cannot lead to an NPE.
* Try: call `notifyListShards()` after `notifySearchStart()` when minimize round trips is true
* Add removed code comment
* Pass `Clusters` to `SearchTask` rather than using progress listener to signify search start. To avoid polluting the progress listener with unnecessary search-specific details, we now pass the `Clusters` object to `SearchTask` when a search op begins. This lets `AsyncSearchTask` access it and use it to initialise `MutableSearchResponse` appropriately.
* Use appropriate `clusters` object rather than re-building it
* Do not double set `mutableSearchResponse`
* Move mutable entities such as shard counts out of `MutableSearchResponse`
* Address PR review: revert moving out mutable entities from `MutableSearchResponse`
* Update docs/changelog/117504.yaml
* Get rid of `SetOnce` for `searchResponse`
* Drop redundant check around shards count
* Add a test that calls `onListShards()` last and clarify `updateShardsAndClusters()`'s comment
* Fix test: ref count
* Address review comment: rewrite comment and test
1 parent a4482d4 commit 01edab5

File tree

4 files changed

+87
-38
lines changed

4 files changed

+87
-38
lines changed

docs/changelog/117504.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 117504
2+
summary: Fix NPE caused by race condition in async search when minimise round trips
3+
is true
4+
area: Search
5+
type: bug
6+
issues: []

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
package org.elasticsearch.xpack.search;
88

99
import org.apache.lucene.search.TotalHits;
10-
import org.apache.lucene.util.SetOnce;
1110
import org.elasticsearch.ElasticsearchException;
1211
import org.elasticsearch.ElasticsearchStatusException;
1312
import org.elasticsearch.ExceptionsHelper;
@@ -73,7 +72,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
7372
private volatile long expirationTimeMillis;
7473
private final AtomicBoolean isCancelling = new AtomicBoolean(false);
7574

76-
private final SetOnce<MutableSearchResponse> searchResponse = new SetOnce<>();
75+
private final MutableSearchResponse searchResponse;
7776

7877
/**
7978
* Creates an instance of {@link AsyncSearchTask}.
@@ -112,6 +111,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
112111
this.aggReduceContextSupplier = aggReduceContextSupplierFactory.apply(this::isCancelled);
113112
this.progressListener = new Listener();
114113
setProgressListener(progressListener);
114+
searchResponse = new MutableSearchResponse(threadPool.getThreadContext());
115115
}
116116

117117
/**
@@ -340,7 +340,7 @@ private AsyncSearchResponse getResponseWithHeaders() {
340340
}
341341

342342
private AsyncSearchResponse getResponse(boolean restoreResponseHeaders) {
343-
MutableSearchResponse mutableSearchResponse = searchResponse.get();
343+
MutableSearchResponse mutableSearchResponse = searchResponse;
344344
assert mutableSearchResponse != null;
345345
checkCancellation();
346346
AsyncSearchResponse asyncSearchResponse;
@@ -370,7 +370,7 @@ private synchronized void checkCancellation() {
370370
* Returns the status from {@link AsyncSearchTask}
371371
*/
372372
public static AsyncStatusResponse getStatusResponse(AsyncSearchTask asyncTask) {
373-
MutableSearchResponse mutableSearchResponse = asyncTask.searchResponse.get();
373+
MutableSearchResponse mutableSearchResponse = asyncTask.searchResponse;
374374
assert mutableSearchResponse != null;
375375
return mutableSearchResponse.toStatusResponse(
376376
asyncTask.searchId.getEncoded(),
@@ -381,7 +381,7 @@ public static AsyncStatusResponse getStatusResponse(AsyncSearchTask asyncTask) {
381381

382382
@Override
383383
public void close() {
384-
Releasables.close(searchResponse.get());
384+
Releasables.close(searchResponse);
385385
}
386386

387387
class Listener extends SearchProgressActionListener {
@@ -420,12 +420,11 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc
420420
if (delegate != null) {
421421
delegate.onQueryFailure(shardIndex, shardTarget, exc);
422422
}
423-
searchResponse.get()
424-
.addQueryFailure(
425-
shardIndex,
426-
// the nodeId is null if all replicas of this shard failed
427-
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null)
428-
);
423+
searchResponse.addQueryFailure(
424+
shardIndex,
425+
// the nodeId is null if all replicas of this shard failed
426+
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null)
427+
);
429428
}
430429

431430
@Override
@@ -467,9 +466,7 @@ protected void onListShards(
467466
delegate = new CCSSingleCoordinatorSearchProgressListener();
468467
delegate.onListShards(shards, skipped, clusters, fetchPhase, timeProvider);
469468
}
470-
searchResponse.set(
471-
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, threadPool.getThreadContext())
472-
);
469+
searchResponse.updateShardsAndClusters(shards.size() + skipped.size(), skipped.size(), clusters);
473470
executeInitListeners();
474471
}
475472

@@ -496,7 +493,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
496493
*/
497494
reducedAggs = () -> InternalAggregations.topLevelReduce(singletonList(aggregations), aggReduceContextSupplier.get());
498495
}
499-
searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
496+
searchResponse.updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
500497
}
501498

502499
/**
@@ -510,7 +507,7 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
510507
if (delegate != null) {
511508
delegate.onFinalReduce(shards, totalHits, aggregations, reducePhase);
512509
}
513-
searchResponse.get().updatePartialResponse(shards.size(), totalHits, () -> aggregations, reducePhase);
510+
searchResponse.updatePartialResponse(shards.size(), totalHits, () -> aggregations, reducePhase);
514511
}
515512

516513
/**
@@ -523,24 +520,20 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
523520
@Override
524521
public void onClusterResponseMinimizeRoundtrips(String clusterAlias, SearchResponse clusterResponse) {
525522
// no need to call the delegate progress listener, since this method is only called for minimize_roundtrips=true
526-
searchResponse.get().updateResponseMinimizeRoundtrips(clusterAlias, clusterResponse);
523+
searchResponse.updateResponseMinimizeRoundtrips(clusterAlias, clusterResponse);
527524
}
528525

529526
@Override
530527
public void onResponse(SearchResponse response) {
531-
searchResponse.get().updateFinalResponse(response, ccsMinimizeRoundtrips);
528+
searchResponse.updateFinalResponse(response, ccsMinimizeRoundtrips);
532529
executeCompletionListeners();
533530
}
534531

535532
@Override
536533
public void onFailure(Exception exc) {
537-
// if the failure occurred before calling onListShards
538-
var r = new MutableSearchResponse(-1, -1, null, threadPool.getThreadContext());
539-
if (searchResponse.trySet(r) == false) {
540-
r.close();
541-
}
542-
searchResponse.get()
543-
.updateWithFailure(new ElasticsearchStatusException("error while executing search", ExceptionsHelper.status(exc), exc));
534+
searchResponse.updateWithFailure(
535+
new ElasticsearchStatusException("error while executing search", ExceptionsHelper.status(exc), exc)
536+
);
544537
executeInitListeners();
545538
executeCompletionListeners();
546539
}

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939
* run concurrently to 1 and ensures that we pause the search progress when an {@link AsyncSearchResponse} is built.
4040
*/
4141
class MutableSearchResponse implements Releasable {
42-
private final int totalShards;
43-
private final int skippedShards;
44-
private final Clusters clusters;
45-
private final AtomicArray<ShardSearchFailure> queryFailures;
42+
private int totalShards;
43+
private int skippedShards;
44+
private Clusters clusters;
45+
private AtomicArray<ShardSearchFailure> queryFailures;
4646
private final ThreadContext threadContext;
4747

4848
private boolean isPartial;
@@ -82,23 +82,31 @@ class MutableSearchResponse implements Releasable {
8282
/**
8383
* Creates a new mutable search response.
8484
*
85-
* @param totalShards The number of shards that participate in the request, or -1 to indicate a failure.
86-
* @param skippedShards The number of skipped shards, or -1 to indicate a failure.
87-
* @param clusters The remote clusters statistics.
8885
* @param threadContext The thread context to retrieve the final response headers.
8986
*/
90-
MutableSearchResponse(int totalShards, int skippedShards, Clusters clusters, ThreadContext threadContext) {
91-
this.totalShards = totalShards;
92-
this.skippedShards = skippedShards;
93-
94-
this.clusters = clusters;
95-
this.queryFailures = totalShards == -1 ? null : new AtomicArray<>(totalShards - skippedShards);
87+
MutableSearchResponse(ThreadContext threadContext) {
9688
this.isPartial = true;
9789
this.threadContext = threadContext;
9890
this.totalHits = Lucene.TOTAL_HITS_GREATER_OR_EQUAL_TO_ZERO;
9991
this.localClusterComplete = false;
10092
}
10193

94+
/**
95+
* Updates the response with the number of total and skipped shards.
96+
*
97+
* @param totalShards The number of shards that participate in the request.
98+
* @param skippedShards The number of shards skipped.
99+
* <p>
100+
* Shards in this context depend on the value of minimize round trips (MRT):
101+
* They are the shards being searched by this coordinator (local only for MRT=true, local + remote otherwise).
102+
*/
103+
synchronized void updateShardsAndClusters(int totalShards, int skippedShards, Clusters clusters) {
104+
this.totalShards = totalShards;
105+
this.skippedShards = skippedShards;
106+
this.queryFailures = new AtomicArray<>(totalShards - skippedShards);
107+
this.clusters = clusters;
108+
}
109+
102110
/**
103111
* Updates the response with the result of a partial reduction.
104112
*

x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.test.client.NoOpClient;
3434
import org.elasticsearch.threadpool.TestThreadPool;
3535
import org.elasticsearch.threadpool.ThreadPool;
36+
import org.elasticsearch.transport.RemoteClusterAware;
3637
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
3738
import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
3839
import org.junit.After;
@@ -424,6 +425,47 @@ public void onFailure(Exception e) {
424425
assertThat(failure.get(), instanceOf(RuntimeException.class));
425426
}
426427

428+
public void testDelayedOnListShardsShouldNotResultInExceptions() throws InterruptedException {
429+
try (AsyncSearchTask task = createAsyncSearchTask()) {
430+
int numShards = randomIntBetween(0, 10);
431+
List<SearchShard> shards = new ArrayList<>();
432+
433+
// All local shards.
434+
for (int i = 0; i < numShards; i++) {
435+
shards.add(new SearchShard(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, new ShardId("0", "0", 1)));
436+
}
437+
438+
int numSkippedShards = randomIntBetween(0, 10);
439+
List<SearchShard> skippedShards = new ArrayList<>();
440+
for (int i = 0; i < numSkippedShards; i++) {
441+
skippedShards.add(new SearchShard(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, new ShardId("0", "0", 1)));
442+
}
443+
444+
int totalShards = numShards + numSkippedShards;
445+
for (int i = 0; i < numShards; i++) {
446+
task.getSearchProgressActionListener()
447+
.onPartialReduce(shards.subList(i, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
448+
}
449+
450+
task.getSearchProgressActionListener()
451+
.onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
452+
453+
SearchResponse searchResponse = newSearchResponse(totalShards, totalShards, numSkippedShards);
454+
task.getSearchProgressActionListener()
455+
.onClusterResponseMinimizeRoundtrips(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, searchResponse);
456+
457+
/**
458+
* We're calling onListShards() at last. Previously, this delay would have resulted in an NPE for other `onABC()` methods.
459+
* Now, we should not see any Exceptions or errors (be it NPE or anything else).
460+
*/
461+
task.getSearchProgressActionListener()
462+
.onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
463+
464+
ActionListener.respondAndRelease((AsyncSearchTask.Listener) task.getProgressListener(), searchResponse);
465+
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, false, false);
466+
}
467+
}
468+
427469
private static SearchResponse newSearchResponse(
428470
int totalShards,
429471
int successfulShards,

0 commit comments

Comments
 (0)