Skip to content

Commit b1b98fc

Browse files
Move duplicate connection lookup logic to AbstractSearchAsyncAction (#117055) (#117611)
We found this duplication today when working on batching query phase requests. For batching it would be nice to have the connection already available at a higher level in the AbstractSearchAsyncAction and this is a worthwhile cleanup in general, given how many issues we had around connection lookup recently.
1 parent 83153e8 commit b1b98fc

File tree

9 files changed

+34
-67
lines changed

9 files changed

+34
-67
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ protected void performPhaseOnShard(final int shardIndex, final SearchShardIterat
324324
}
325325

326326
private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard, Releasable releasable) {
327-
executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
327+
var shardListener = new SearchActionListener<Result>(shard, shardIndex) {
328328
@Override
329329
public void innerOnResponse(Result result) {
330330
try {
@@ -340,7 +340,15 @@ public void onFailure(Exception e) {
340340
releasable.close();
341341
onShardFailure(shardIndex, shard, shardIt, e);
342342
}
343-
});
343+
};
344+
final Transport.Connection connection;
345+
try {
346+
connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
347+
} catch (Exception e) {
348+
shardListener.onFailure(e);
349+
return;
350+
}
351+
executePhaseOnShard(shardIt, connection, shardListener);
344352
}
345353

346354
private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
@@ -352,12 +360,12 @@ private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
352360
/**
353361
* Sends the request to the actual shard.
354362
* @param shardIt the shards iterator
355-
* @param shard the shard routing to send the request for
363+
* @param connection to node that the shard is located on
356364
* @param listener the listener to notify on response
357365
*/
358366
protected abstract void executePhaseOnShard(
359367
SearchShardIterator shardIt,
360-
SearchShardTarget shard,
368+
Transport.Connection connection,
361369
SearchActionListener<Result> listener
362370
);
363371

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,9 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
8484
@Override
8585
protected void executePhaseOnShard(
8686
final SearchShardIterator shardIt,
87-
final SearchShardTarget shard,
87+
final Transport.Connection connection,
8888
final SearchActionListener<DfsSearchResult> listener
8989
) {
90-
final Transport.Connection connection;
91-
try {
92-
connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
93-
} catch (Exception e) {
94-
listener.onFailure(e);
95-
return;
96-
}
9790
getSearchTransport().sendExecuteDfs(connection, buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
9891
}
9992

server/src/main/java/org/elasticsearch/action/search/SearchPhase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ protected void doCheckNoMissingShards(String phaseName, SearchRequest request, G
7474
/**
7575
* Releases shard targets that are not used in the docsIdsToLoad.
7676
*/
77-
protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AbstractSearchAsyncAction<?> context) {
77+
protected static void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AbstractSearchAsyncAction<?> context) {
7878
// we only release search context that we did not fetch from, if we are not scrolling
7979
// or using a PIT and if it has at least one hit that didn't make it to the global topDocs
8080
if (searchPhaseResult == null) {

server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,9 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
9191

9292
protected void executePhaseOnShard(
9393
final SearchShardIterator shardIt,
94-
final SearchShardTarget shard,
94+
final Transport.Connection connection,
9595
final SearchActionListener<SearchPhaseResult> listener
9696
) {
97-
final Transport.Connection connection;
98-
try {
99-
connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
100-
} catch (Exception e) {
101-
listener.onFailure(e);
102-
return;
103-
}
10497
ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex));
10598
getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener);
10699
}

server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.elasticsearch.rest.RestStatus;
3636
import org.elasticsearch.search.SearchPhaseResult;
3737
import org.elasticsearch.search.SearchService;
38-
import org.elasticsearch.search.SearchShardTarget;
3938
import org.elasticsearch.search.builder.SearchSourceBuilder;
4039
import org.elasticsearch.search.internal.AliasFilter;
4140
import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -253,16 +252,9 @@ protected String missingShardsErrorMessage(StringBuilder missingShards) {
253252
@Override
254253
protected void executePhaseOnShard(
255254
SearchShardIterator shardIt,
256-
SearchShardTarget shard,
255+
Transport.Connection connection,
257256
SearchActionListener<SearchPhaseResult> phaseListener
258257
) {
259-
final Transport.Connection connection;
260-
try {
261-
connection = connectionLookup.apply(shardIt.getClusterAlias(), shard.getNodeId());
262-
} catch (Exception e) {
263-
phaseListener.onFailure(e);
264-
return;
265-
}
266258
transportService.sendChildRequest(
267259
connection,
268260
OPEN_SHARD_READER_CONTEXT_NAME,

server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ protected SearchPhase getNextPhase() {
101101
@Override
102102
protected void executePhaseOnShard(
103103
final SearchShardIterator shardIt,
104-
final SearchShardTarget shard,
104+
final Transport.Connection shard,
105105
final SearchActionListener<SearchPhaseResult> listener
106106
) {}
107107

server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase> nex
147147
@Override
148148
protected void executePhaseOnShard(
149149
SearchShardIterator shardIt,
150-
SearchShardTarget shard,
150+
Transport.Connection shard,
151151
SearchActionListener<SearchPhaseResult> listener
152152
) {
153153
onShardResult(new SearchPhaseResult() {

server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.index.shard.ShardId;
2525
import org.elasticsearch.search.SearchHits;
2626
import org.elasticsearch.search.SearchPhaseResult;
27-
import org.elasticsearch.search.SearchShardTarget;
2827
import org.elasticsearch.search.internal.AliasFilter;
2928
import org.elasticsearch.search.internal.ShardSearchContextId;
3029
import org.elasticsearch.test.ESTestCase;
@@ -119,16 +118,15 @@ public void testSkipSearchShards() throws InterruptedException {
119118
@Override
120119
protected void executePhaseOnShard(
121120
SearchShardIterator shardIt,
122-
SearchShardTarget shard,
121+
Transport.Connection connection,
123122
SearchActionListener<TestSearchPhaseResult> listener
124123
) {
125-
seenShard.computeIfAbsent(shard.getShardId(), (i) -> {
124+
seenShard.computeIfAbsent(shardIt.shardId(), (i) -> {
126125
numRequests.incrementAndGet(); // only count this once per replica
127126
return Boolean.TRUE;
128127
});
129128

130129
new Thread(() -> {
131-
Transport.Connection connection = getConnection(null, shard.getNodeId());
132130
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
133131
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
134132
connection.getNode()
@@ -227,23 +225,22 @@ public void testLimitConcurrentShardRequests() throws InterruptedException {
227225
@Override
228226
protected void executePhaseOnShard(
229227
SearchShardIterator shardIt,
230-
SearchShardTarget shard,
228+
Transport.Connection connection,
231229
SearchActionListener<TestSearchPhaseResult> listener
232230
) {
233-
seenShard.computeIfAbsent(shard.getShardId(), (i) -> {
231+
seenShard.computeIfAbsent(shardIt.shardId(), (i) -> {
234232
numRequests.incrementAndGet(); // only count this once per shard copy
235233
return Boolean.TRUE;
236234
});
237235

238236
new Thread(() -> {
239237
safeAwait(awaitInitialRequests);
240-
Transport.Connection connection = getConnection(null, shard.getNodeId());
241238
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
242239
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
243240
connection.getNode()
244241
);
245242
try {
246-
if (shardFailures[shard.getShardId().id()]) {
243+
if (shardFailures[shardIt.shardId().id()]) {
247244
listener.onFailure(new RuntimeException());
248245
} else {
249246
listener.onResponse(testSearchPhaseResult);
@@ -340,11 +337,11 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
340337
@Override
341338
protected void executePhaseOnShard(
342339
SearchShardIterator shardIt,
343-
SearchShardTarget shard,
340+
Transport.Connection connection,
344341
SearchActionListener<TestSearchPhaseResult> listener
345342
) {
346-
assertTrue("shard: " + shard.getShardId() + " has been queried twice", testResponse.queried.add(shard.getShardId()));
347-
Transport.Connection connection = getConnection(null, shard.getNodeId());
343+
var shardId = shardIt.shardId();
344+
assertTrue("shard: " + shardId + " has been queried twice", testResponse.queried.add(shardId));
348345
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
349346
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
350347
connection.getNode()
@@ -464,13 +461,13 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
464461
@Override
465462
protected void executePhaseOnShard(
466463
SearchShardIterator shardIt,
467-
SearchShardTarget shard,
464+
Transport.Connection connection,
468465
SearchActionListener<TestSearchPhaseResult> listener
469466
) {
470-
assertTrue("shard: " + shard.getShardId() + " has been queried twice", response.queried.add(shard.getShardId()));
471-
Transport.Connection connection = getConnection(null, shard.getNodeId());
467+
var shardId = shardIt.shardId();
468+
assertTrue("shard: " + shardId + " has been queried twice", response.queried.add(shardId));
472469
final TestSearchPhaseResult testSearchPhaseResult;
473-
if (shard.getShardId().id() == 0) {
470+
if (shardId.id() == 0) {
474471
testSearchPhaseResult = new TestSearchPhaseResult(null, connection.getNode());
475472
} else {
476473
testSearchPhaseResult = new TestSearchPhaseResult(
@@ -573,15 +570,14 @@ public void testAllowPartialResults() throws InterruptedException {
573570
@Override
574571
protected void executePhaseOnShard(
575572
SearchShardIterator shardIt,
576-
SearchShardTarget shard,
573+
Transport.Connection connection,
577574
SearchActionListener<TestSearchPhaseResult> listener
578575
) {
579-
seenShard.computeIfAbsent(shard.getShardId(), (i) -> {
576+
seenShard.computeIfAbsent(shardIt.shardId(), (i) -> {
580577
numRequests.incrementAndGet(); // only count this once per shard copy
581578
return Boolean.TRUE;
582579
});
583580
new Thread(() -> {
584-
Transport.Connection connection = getConnection(null, shard.getNodeId());
585581
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(
586582
new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()),
587583
connection.getNode()
@@ -673,7 +669,7 @@ public void testSkipUnavailableSearchShards() throws InterruptedException {
673669
@Override
674670
protected void executePhaseOnShard(
675671
SearchShardIterator shardIt,
676-
SearchShardTarget shard,
672+
Transport.Connection connection,
677673
SearchActionListener<TestSearchPhaseResult> listener
678674
) {
679675
assert false : "Expected to skip all shards";

server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.elasticsearch.Version;
1717
import org.elasticsearch.action.ActionListener;
1818
import org.elasticsearch.action.OriginalIndices;
19-
import org.elasticsearch.action.support.PlainActionFuture;
2019
import org.elasticsearch.cluster.ClusterName;
2120
import org.elasticsearch.cluster.ClusterState;
2221
import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -733,21 +732,7 @@ public void run() {
733732
assertThat(phase.totalHits().value, equalTo(2L));
734733
assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO));
735734

736-
SearchShardTarget searchShardTarget = new SearchShardTarget("node3", shardIt.shardId(), null);
737-
final PlainActionFuture<Void> f = new PlainActionFuture<>();
738-
SearchActionListener<SearchPhaseResult> listener = new SearchActionListener<SearchPhaseResult>(searchShardTarget, 0) {
739-
@Override
740-
public void onFailure(Exception e) {
741-
f.onFailure(e);
742-
}
743-
744-
@Override
745-
protected void innerOnResponse(SearchPhaseResult response) {
746-
fail("should not be called");
747-
}
748-
};
749-
action.executePhaseOnShard(shardIt, searchShardTarget, listener);
750-
Exception e = expectThrows(VersionMismatchException.class, f::actionGet);
735+
Exception e = expectThrows(VersionMismatchException.class, () -> action.getConnection(null, "node3"));
751736
assertThat(e.getMessage(), equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]"));
752737
}
753738
}

0 commit comments

Comments
 (0)