Skip to content

Commit 261ad85

Browse files
Move duplicate connection lookup logic to AbstractSearchAsyncAction (#117055)
We found this duplication today when working on batching query phase requests. For batching it would be nice to have the connection already available at a higher level in the AbstractSearchAsyncAction and this is a worthwhile cleanup in general, given how many issues we had around connection lookup recently.
1 parent 3b0d7e0 commit 261ad85

File tree

8 files changed

+33
-51
lines changed

8 files changed

+33
-51
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ protected void performPhaseOnShard(final int shardIndex, final SearchShardIterat
299299
}
300300

301301
private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard, Releasable releasable) {
302-
executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
302+
var shardListener = new SearchActionListener<Result>(shard, shardIndex) {
303303
@Override
304304
public void innerOnResponse(Result result) {
305305
try {
@@ -315,7 +315,15 @@ public void onFailure(Exception e) {
315315
releasable.close();
316316
onShardFailure(shardIndex, shard, shardIt, e);
317317
}
318-
});
318+
};
319+
final Transport.Connection connection;
320+
try {
321+
connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
322+
} catch (Exception e) {
323+
shardListener.onFailure(e);
324+
return;
325+
}
326+
executePhaseOnShard(shardIt, connection, shardListener);
319327
}
320328

321329
private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
@@ -327,12 +335,12 @@ private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
327335
/**
328336
* Sends the request to the actual shard.
329337
* @param shardIt the shards iterator
330-
* @param shard the shard routing to send the request for
338+
* @param connection to node that the shard is located on
331339
* @param listener the listener to notify on response
332340
*/
333341
protected abstract void executePhaseOnShard(
334342
SearchShardIterator shardIt,
335-
SearchShardTarget shard,
343+
Transport.Connection connection,
336344
SearchActionListener<Result> listener
337345
);
338346

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,9 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
8484
@Override
8585
protected void executePhaseOnShard(
8686
final SearchShardIterator shardIt,
87-
final SearchShardTarget shard,
87+
final Transport.Connection connection,
8888
final SearchActionListener<DfsSearchResult> listener
8989
) {
90-
final Transport.Connection connection;
91-
try {
92-
connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
93-
} catch (Exception e) {
94-
listener.onFailure(e);
95-
return;
96-
}
9790
getSearchTransport().sendExecuteDfs(connection, buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
9891
}
9992

server/src/main/java/org/elasticsearch/action/search/SearchPhase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ protected static void doCheckNoMissingShards(
7979
/**
8080
* Releases shard targets that are not used in the docsIdsToLoad.
8181
*/
82-
protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AbstractSearchAsyncAction<?> context) {
82+
protected static void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AbstractSearchAsyncAction<?> context) {
8383
// we only release search context that we did not fetch from, if we are not scrolling
8484
// or using a PIT and if it has at least one hit that didn't make it to the global topDocs
8585
if (searchPhaseResult == null) {

server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,9 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
9191

9292
protected void executePhaseOnShard(
9393
final SearchShardIterator shardIt,
94-
final SearchShardTarget shard,
94+
final Transport.Connection connection,
9595
final SearchActionListener<SearchPhaseResult> listener
9696
) {
97-
final Transport.Connection connection;
98-
try {
99-
connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
100-
} catch (Exception e) {
101-
listener.onFailure(e);
102-
return;
103-
}
10497
ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex));
10598
getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener);
10699
}

server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.elasticsearch.rest.RestStatus;
3636
import org.elasticsearch.search.SearchPhaseResult;
3737
import org.elasticsearch.search.SearchService;
38-
import org.elasticsearch.search.SearchShardTarget;
3938
import org.elasticsearch.search.builder.SearchSourceBuilder;
4039
import org.elasticsearch.search.internal.AliasFilter;
4140
import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -252,16 +251,9 @@ protected String missingShardsErrorMessage(StringBuilder missingShards) {
252251
@Override
253252
protected void executePhaseOnShard(
254253
SearchShardIterator shardIt,
255-
SearchShardTarget shard,
254+
Transport.Connection connection,
256255
SearchActionListener<SearchPhaseResult> phaseListener
257256
) {
258-
final Transport.Connection connection;
259-
try {
260-
connection = connectionLookup.apply(shardIt.getClusterAlias(), shard.getNodeId());
261-
} catch (Exception e) {
262-
phaseListener.onFailure(e);
263-
return;
264-
}
265257
transportService.sendChildRequest(
266258
connection,
267259
OPEN_SHARD_READER_CONTEXT_NAME,

server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ protected SearchPhase getNextPhase() {
101101
@Override
102102
protected void executePhaseOnShard(
103103
final SearchShardIterator shardIt,
104-
final SearchShardTarget shard,
104+
final Transport.Connection shard,
105105
final SearchActionListener<SearchPhaseResult> listener
106106
) {}
107107

server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase> nex
147147
@Override
148148
protected void executePhaseOnShard(
149149
SearchShardIterator shardIt,
150-
SearchShardTarget shard,
150+
Transport.Connection shard,
151151
SearchActionListener<SearchPhaseResult> listener
152152
) {
153153
onShardResult(new SearchPhaseResult() {

server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.index.shard.ShardId;
2525
import org.elasticsearch.search.SearchHits;
2626
import org.elasticsearch.search.SearchPhaseResult;
27-
import org.elasticsearch.search.SearchShardTarget;
2827
import org.elasticsearch.search.internal.AliasFilter;
2928
import org.elasticsearch.search.internal.ShardSearchContextId;
3029
import org.elasticsearch.test.ESTestCase;
@@ -119,16 +118,15 @@ public void testSkipSearchShards() throws InterruptedException {
119118
@Override
120119
protected void executePhaseOnShard(
121120
SearchShardIterator shardIt,
122-
SearchShardTarget shard,
121+
Transport.Connection connection,
123122
SearchActionListener<TestSearchPhaseResult> listener
124123
) {
125-
seenShard.computeIfAbsent(shard.getShardId(), (i) -> {
124+
seenShard.computeIfAbsent(shardIt.shardId(), (i) -> {
126125
numRequests.incrementAndGet(); // only count this once per replica
127126
return Boolean.TRUE;
128127
});
129128

130129
new Thread(() -> {
131-
Transport.Connection connection = getConnection(null, shard.getNodeId());
132130
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
133131
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
134132
connection.getNode()
@@ -227,23 +225,22 @@ public void testLimitConcurrentShardRequests() throws InterruptedException {
227225
@Override
228226
protected void executePhaseOnShard(
229227
SearchShardIterator shardIt,
230-
SearchShardTarget shard,
228+
Transport.Connection connection,
231229
SearchActionListener<TestSearchPhaseResult> listener
232230
) {
233-
seenShard.computeIfAbsent(shard.getShardId(), (i) -> {
231+
seenShard.computeIfAbsent(shardIt.shardId(), (i) -> {
234232
numRequests.incrementAndGet(); // only count this once per shard copy
235233
return Boolean.TRUE;
236234
});
237235

238236
new Thread(() -> {
239237
safeAwait(awaitInitialRequests);
240-
Transport.Connection connection = getConnection(null, shard.getNodeId());
241238
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
242239
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
243240
connection.getNode()
244241
);
245242
try {
246-
if (shardFailures[shard.getShardId().id()]) {
243+
if (shardFailures[shardIt.shardId().id()]) {
247244
listener.onFailure(new RuntimeException());
248245
} else {
249246
listener.onResponse(testSearchPhaseResult);
@@ -340,11 +337,11 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
340337
@Override
341338
protected void executePhaseOnShard(
342339
SearchShardIterator shardIt,
343-
SearchShardTarget shard,
340+
Transport.Connection connection,
344341
SearchActionListener<TestSearchPhaseResult> listener
345342
) {
346-
assertTrue("shard: " + shard.getShardId() + " has been queried twice", testResponse.queried.add(shard.getShardId()));
347-
Transport.Connection connection = getConnection(null, shard.getNodeId());
343+
var shardId = shardIt.shardId();
344+
assertTrue("shard: " + shardId + " has been queried twice", testResponse.queried.add(shardId));
348345
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
349346
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
350347
connection.getNode()
@@ -464,13 +461,13 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
464461
@Override
465462
protected void executePhaseOnShard(
466463
SearchShardIterator shardIt,
467-
SearchShardTarget shard,
464+
Transport.Connection connection,
468465
SearchActionListener<TestSearchPhaseResult> listener
469466
) {
470-
assertTrue("shard: " + shard.getShardId() + " has been queried twice", response.queried.add(shard.getShardId()));
471-
Transport.Connection connection = getConnection(null, shard.getNodeId());
467+
var shardId = shardIt.shardId();
468+
assertTrue("shard: " + shardId + " has been queried twice", response.queried.add(shardId));
472469
final TestSearchPhaseResult testSearchPhaseResult;
473-
if (shard.getShardId().id() == 0) {
470+
if (shardId.id() == 0) {
474471
testSearchPhaseResult = new TestSearchPhaseResult(null, connection.getNode());
475472
} else {
476473
testSearchPhaseResult = new TestSearchPhaseResult(
@@ -573,15 +570,14 @@ public void testAllowPartialResults() throws InterruptedException {
573570
@Override
574571
protected void executePhaseOnShard(
575572
SearchShardIterator shardIt,
576-
SearchShardTarget shard,
573+
Transport.Connection connection,
577574
SearchActionListener<TestSearchPhaseResult> listener
578575
) {
579-
seenShard.computeIfAbsent(shard.getShardId(), (i) -> {
576+
seenShard.computeIfAbsent(shardIt.shardId(), (i) -> {
580577
numRequests.incrementAndGet(); // only count this once per shard copy
581578
return Boolean.TRUE;
582579
});
583580
new Thread(() -> {
584-
Transport.Connection connection = getConnection(null, shard.getNodeId());
585581
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
586582
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
587583
connection.getNode()
@@ -673,7 +669,7 @@ public void testSkipUnavailableSearchShards() throws InterruptedException {
673669
@Override
674670
protected void executePhaseOnShard(
675671
SearchShardIterator shardIt,
676-
SearchShardTarget shard,
672+
Transport.Connection connection,
677673
SearchActionListener<TestSearchPhaseResult> listener
678674
) {
679675
assert false : "Expected to skip all shards";

0 commit comments

Comments
 (0)