Skip to content

Commit 9bbd171

Browse files
Simplify and flatten CanMatchPreFilterSearchPhase (#118558) (#124876)
We don't need the result to be a `SearchPhaseResults`. In fact, there is no reason for it to be a nested class in the first place. Just flatten it into the phase itself and synchronize on `this`. This is just a step on the way to #116881 that makes that PR much easier to review I believe.
1 parent 7088cb9 commit 9bbd171

File tree

5 files changed

+103
-161
lines changed

5 files changed

+103
-161
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,6 @@ void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Excepti
533533
successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter
534534
}
535535
}
536-
results.consumeShardFailure(shardIndex);
537536
}
538537

539538
private static boolean isTaskCancelledException(Exception e) {

server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java

Lines changed: 84 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141
import java.util.concurrent.Executor;
4242
import java.util.concurrent.atomic.AtomicReferenceArray;
4343
import java.util.function.BiFunction;
44-
import java.util.stream.Collectors;
45-
import java.util.stream.Stream;
4644

4745
import static org.elasticsearch.core.Strings.format;
4846
import static org.elasticsearch.core.Types.forciblyCast;
@@ -58,7 +56,7 @@
5856
* sort them according to the provided order. This can be useful for instance to ensure that shards that contain recent
5957
* data are executed first when sorting by descending timestamp.
6058
*/
61-
final class CanMatchPreFilterSearchPhase extends SearchPhase {
59+
final class CanMatchPreFilterSearchPhase {
6260

6361
private final Logger logger;
6462
private final SearchRequest request;
@@ -74,7 +72,9 @@ final class CanMatchPreFilterSearchPhase extends SearchPhase {
7472
private final Executor executor;
7573
private final boolean requireAtLeastOneMatch;
7674

77-
private final CanMatchSearchPhaseResults results;
75+
private final FixedBitSet possibleMatches;
76+
private final MinAndMax<?>[] minAndMaxes;
77+
private int numPossibleMatches;
7878
private final CoordinatorRewriteContextProvider coordinatorRewriteContextProvider;
7979

8080
CanMatchPreFilterSearchPhase(
@@ -92,7 +92,6 @@ final class CanMatchPreFilterSearchPhase extends SearchPhase {
9292
CoordinatorRewriteContextProvider coordinatorRewriteContextProvider,
9393
ActionListener<List<SearchShardIterator>> listener
9494
) {
95-
super("can_match");
9695
this.logger = logger;
9796
this.searchTransportService = searchTransportService;
9897
this.nodeIdToConnection = nodeIdToConnection;
@@ -106,12 +105,13 @@ final class CanMatchPreFilterSearchPhase extends SearchPhase {
106105
this.requireAtLeastOneMatch = requireAtLeastOneMatch;
107106
this.coordinatorRewriteContextProvider = coordinatorRewriteContextProvider;
108107
this.executor = executor;
109-
results = new CanMatchSearchPhaseResults(shardsIts.size());
110-
108+
final int size = shardsIts.size();
109+
possibleMatches = new FixedBitSet(size);
110+
minAndMaxes = new MinAndMax<?>[size];
111111
// we compute the shard index based on the natural order of the shards
112112
// that participate in the search request. This means that this number is
113113
// consistent between two requests that target the same shards.
114-
final SearchShardIterator[] naturalOrder = new SearchShardIterator[shardsIts.size()];
114+
final SearchShardIterator[] naturalOrder = new SearchShardIterator[size];
115115
int i = 0;
116116
for (SearchShardIterator shardsIt : shardsIts) {
117117
naturalOrder[i++] = shardsIt;
@@ -128,21 +128,6 @@ private static boolean assertSearchCoordinationThread() {
128128
return ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION);
129129
}
130130

131-
@Override
132-
public void run() {
133-
assert assertSearchCoordinationThread();
134-
Version version = request.minCompatibleShardNode();
135-
if (version != null && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false) {
136-
if (checkMinimumVersion(shardsIts) == false) {
137-
throw new VersionMismatchException(
138-
"One of the shards is incompatible with the required minimum version [{}]",
139-
request.minCompatibleShardNode()
140-
);
141-
}
142-
}
143-
runCoordinatorRewritePhase();
144-
}
145-
146131
// tries to pre-filter shards based on information that's available to the coordinator
147132
// without having to reach out to the actual shards
148133
private void runCoordinatorRewritePhase() {
@@ -154,7 +139,7 @@ private void runCoordinatorRewritePhase() {
154139
request,
155140
searchShardIterator.getOriginalIndices().indicesOptions(),
156141
Collections.emptyList(),
157-
getNumShards(),
142+
shardsIts.size(),
158143
timeProvider.absoluteStartMillis(),
159144
searchShardIterator.getClusterAlias()
160145
);
@@ -192,12 +177,32 @@ private void runCoordinatorRewritePhase() {
192177
private void consumeResult(boolean canMatch, ShardSearchRequest request) {
193178
CanMatchShardResponse result = new CanMatchShardResponse(canMatch, null);
194179
result.setShardIndex(request.shardRequestIndex());
195-
results.consumeResult(result, () -> {});
180+
consumeResult(result, () -> {});
181+
}
182+
183+
private void consumeResult(CanMatchShardResponse result, Runnable next) {
184+
try {
185+
final boolean canMatch = result.canMatch();
186+
final MinAndMax<?> minAndMax = result.estimatedMinAndMax();
187+
if (canMatch || minAndMax != null) {
188+
consumeResult(result.getShardIndex(), canMatch, minAndMax);
189+
}
190+
} finally {
191+
next.run();
192+
}
193+
}
194+
195+
private synchronized void consumeResult(int shardIndex, boolean canMatch, MinAndMax<?> minAndMax) {
196+
if (canMatch) {
197+
possibleMatches.set(shardIndex);
198+
numPossibleMatches++;
199+
}
200+
minAndMaxes[shardIndex] = minAndMax;
196201
}
197202

198203
private void checkNoMissingShards(List<SearchShardIterator> shards) {
199204
assert assertSearchCoordinationThread();
200-
doCheckNoMissingShards(getName(), request, shards);
205+
SearchPhase.doCheckNoMissingShards("can_match", request, shards);
201206
}
202207

203208
private Map<SendingTarget, List<SearchShardIterator>> groupByNode(List<SearchShardIterator> shards) {
@@ -250,32 +255,38 @@ protected void doRun() {
250255
continue;
251256
}
252257

258+
var sendingTarget = entry.getKey();
253259
try {
254-
searchTransportService.sendCanMatch(getConnection(entry.getKey()), canMatchNodeRequest, task, new ActionListener<>() {
255-
@Override
256-
public void onResponse(CanMatchNodeResponse canMatchNodeResponse) {
257-
assert canMatchNodeResponse.getResponses().size() == canMatchNodeRequest.getShardLevelRequests().size();
258-
for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) {
259-
CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i);
260-
if (response.getResponse() != null) {
261-
CanMatchShardResponse shardResponse = response.getResponse();
262-
shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex());
263-
onOperation(shardResponse.getShardIndex(), shardResponse);
264-
} else {
265-
Exception failure = response.getException();
266-
assert failure != null;
267-
onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure);
260+
searchTransportService.sendCanMatch(
261+
nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId),
262+
canMatchNodeRequest,
263+
task,
264+
new ActionListener<>() {
265+
@Override
266+
public void onResponse(CanMatchNodeResponse canMatchNodeResponse) {
267+
assert canMatchNodeResponse.getResponses().size() == canMatchNodeRequest.getShardLevelRequests().size();
268+
for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) {
269+
CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i);
270+
if (response.getResponse() != null) {
271+
CanMatchShardResponse shardResponse = response.getResponse();
272+
shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex());
273+
onOperation(shardResponse.getShardIndex(), shardResponse);
274+
} else {
275+
Exception failure = response.getException();
276+
assert failure != null;
277+
onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure);
278+
}
268279
}
269280
}
270-
}
271281

272-
@Override
273-
public void onFailure(Exception e) {
274-
for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
275-
onOperationFailed(shard.getShardRequestIndex(), e);
282+
@Override
283+
public void onFailure(Exception e) {
284+
for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
285+
onOperationFailed(shard.getShardRequestIndex(), e);
286+
}
276287
}
277288
}
278-
});
289+
);
279290
} catch (Exception e) {
280291
for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
281292
onOperationFailed(shard.getShardRequestIndex(), e);
@@ -286,7 +297,7 @@ public void onFailure(Exception e) {
286297

287298
private void onOperation(int idx, CanMatchShardResponse response) {
288299
failedResponses.set(idx, null);
289-
results.consumeResult(response, () -> {
300+
consumeResult(response, () -> {
290301
if (countDown.countDown()) {
291302
finishRound();
292303
}
@@ -295,7 +306,8 @@ private void onOperation(int idx, CanMatchShardResponse response) {
295306

296307
private void onOperationFailed(int idx, Exception e) {
297308
failedResponses.set(idx, e);
298-
results.consumeShardFailure(idx);
309+
// we have to carry over shard failures in order to account for them in the response.
310+
consumeResult(idx, true, null);
299311
if (countDown.countDown()) {
300312
finishRound();
301313
}
@@ -326,7 +338,7 @@ public boolean isForceExecution() {
326338
@Override
327339
public void onFailure(Exception e) {
328340
if (logger.isDebugEnabled()) {
329-
logger.debug(() -> format("Failed to execute [%s] while running [%s] phase", request, getName()), e);
341+
logger.debug(() -> format("Failed to execute [%s] while running [can_match] phase", request), e);
330342
}
331343
onPhaseFailure("round", e);
332344
}
@@ -336,10 +348,7 @@ private record SendingTarget(@Nullable String clusterAlias, @Nullable String nod
336348

337349
private CanMatchNodeRequest createCanMatchRequest(Map.Entry<SendingTarget, List<SearchShardIterator>> entry) {
338350
final SearchShardIterator first = entry.getValue().get(0);
339-
final List<CanMatchNodeRequest.Shard> shardLevelRequests = entry.getValue()
340-
.stream()
341-
.map(this::buildShardLevelRequest)
342-
.collect(Collectors.toCollection(ArrayList::new));
351+
final List<CanMatchNodeRequest.Shard> shardLevelRequests = entry.getValue().stream().map(this::buildShardLevelRequest).toList();
343352
assert entry.getValue().stream().allMatch(Objects::nonNull);
344353
assert entry.getValue()
345354
.stream()
@@ -349,14 +358,14 @@ private CanMatchNodeRequest createCanMatchRequest(Map.Entry<SendingTarget, List<
349358
request,
350359
first.getOriginalIndices().indicesOptions(),
351360
shardLevelRequests,
352-
getNumShards(),
361+
shardsIts.size(),
353362
timeProvider.absoluteStartMillis(),
354363
first.getClusterAlias()
355364
);
356365
}
357366

358367
private void finishPhase() {
359-
listener.onResponse(getIterator(results, shardsIts));
368+
listener.onResponse(getIterator(shardsIts));
360369
}
361370

362371
private static final float DEFAULT_INDEX_BOOST = 1.0f;
@@ -382,7 +391,7 @@ private boolean checkMinimumVersion(List<SearchShardIterator> shardsIts) {
382391
for (SearchShardIterator it : shardsIts) {
383392
if (it.getTargetNodeIds().isEmpty() == false) {
384393
boolean isCompatible = it.getTargetNodeIds().stream().anyMatch(nodeId -> {
385-
Transport.Connection conn = getConnection(new SendingTarget(it.getClusterAlias(), nodeId));
394+
Transport.Connection conn = nodeIdToConnection.apply(it.getClusterAlias(), nodeId);
386395
return conn == null || conn.getNode().getVersion().onOrAfter(request.minCompatibleShardNode());
387396
});
388397
if (isCompatible == false) {
@@ -393,9 +402,8 @@ private boolean checkMinimumVersion(List<SearchShardIterator> shardsIts) {
393402
return true;
394403
}
395404

396-
@Override
397405
public void start() {
398-
if (getNumShards() == 0) {
406+
if (shardsIts.isEmpty()) {
399407
finishPhase();
400408
return;
401409
}
@@ -404,99 +412,35 @@ public void start() {
404412
@Override
405413
public void onFailure(Exception e) {
406414
if (logger.isDebugEnabled()) {
407-
logger.debug(() -> format("Failed to execute [%s] while running [%s] phase", request, getName()), e);
415+
logger.debug(() -> format("Failed to execute [%s] while running [can_match] phase", request), e);
408416
}
409417
onPhaseFailure("start", e);
410418
}
411419

412420
@Override
413421
protected void doRun() {
414-
CanMatchPreFilterSearchPhase.this.run();
422+
assert assertSearchCoordinationThread();
423+
Version version = request.minCompatibleShardNode();
424+
if (version != null && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false) {
425+
if (checkMinimumVersion(shardsIts) == false) {
426+
throw new VersionMismatchException(
427+
"One of the shards is incompatible with the required minimum version [{}]",
428+
request.minCompatibleShardNode()
429+
);
430+
}
431+
}
432+
runCoordinatorRewritePhase();
415433
}
416434
});
417435
}
418436

419-
public void onPhaseFailure(String msg, Exception cause) {
420-
listener.onFailure(new SearchPhaseExecutionException(getName(), msg, cause, ShardSearchFailure.EMPTY_ARRAY));
421-
}
422-
423-
public Transport.Connection getConnection(SendingTarget sendingTarget) {
424-
Transport.Connection conn = nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId);
425-
Version minVersion = request.minCompatibleShardNode();
426-
if (minVersion != null && conn != null && conn.getNode().getVersion().before(minVersion)) {
427-
throw new VersionMismatchException("One of the shards is incompatible with the required minimum version [{}]", minVersion);
428-
}
429-
return conn;
430-
}
431-
432-
private int getNumShards() {
433-
return shardsIts.size();
434-
}
435-
436-
private static final class CanMatchSearchPhaseResults extends SearchPhaseResults<CanMatchShardResponse> {
437-
private final FixedBitSet possibleMatches;
438-
private final MinAndMax<?>[] minAndMaxes;
439-
private int numPossibleMatches;
440-
441-
CanMatchSearchPhaseResults(int size) {
442-
super(size);
443-
possibleMatches = new FixedBitSet(size);
444-
minAndMaxes = new MinAndMax<?>[size];
445-
}
446-
447-
@Override
448-
void consumeResult(CanMatchShardResponse result, Runnable next) {
449-
try {
450-
final boolean canMatch = result.canMatch();
451-
final MinAndMax<?> minAndMax = result.estimatedMinAndMax();
452-
if (canMatch || minAndMax != null) {
453-
consumeResult(result.getShardIndex(), canMatch, minAndMax);
454-
}
455-
} finally {
456-
next.run();
457-
}
458-
}
459-
460-
@Override
461-
boolean hasResult(int shardIndex) {
462-
return false; // unneeded
463-
}
464-
465-
@Override
466-
void consumeShardFailure(int shardIndex) {
467-
// we have to carry over shard failures in order to account for them in the response.
468-
consumeResult(shardIndex, true, null);
469-
}
470-
471-
private synchronized void consumeResult(int shardIndex, boolean canMatch, MinAndMax<?> minAndMax) {
472-
if (canMatch) {
473-
possibleMatches.set(shardIndex);
474-
numPossibleMatches++;
475-
}
476-
minAndMaxes[shardIndex] = minAndMax;
477-
}
478-
479-
synchronized int getNumPossibleMatches() {
480-
return numPossibleMatches;
481-
}
482-
483-
synchronized FixedBitSet getPossibleMatches() {
484-
return possibleMatches;
485-
}
486-
487-
@Override
488-
Stream<CanMatchShardResponse> getSuccessfulResults() {
489-
return Stream.empty();
490-
}
491-
492-
@Override
493-
public void close() {}
437+
private void onPhaseFailure(String msg, Exception cause) {
438+
listener.onFailure(new SearchPhaseExecutionException("can_match", msg, cause, ShardSearchFailure.EMPTY_ARRAY));
494439
}
495440

496-
private List<SearchShardIterator> getIterator(CanMatchSearchPhaseResults results, List<SearchShardIterator> shardsIts) {
497-
FixedBitSet possibleMatches = results.getPossibleMatches();
441+
private synchronized List<SearchShardIterator> getIterator(List<SearchShardIterator> shardsIts) {
498442
// TODO: pick the local shard when possible
499-
if (requireAtLeastOneMatch && results.getNumPossibleMatches() == 0) {
443+
if (requireAtLeastOneMatch && numPossibleMatches == 0) {
500444
// this is a special case where we have no hit but we need to get at least one search response in order
501445
// to produce a valid search result with all the aggs etc.
502446
// Since it's possible that some of the shards that we're skipping are
@@ -523,11 +467,11 @@ private List<SearchShardIterator> getIterator(CanMatchSearchPhaseResults results
523467
iter.skip(true);
524468
}
525469
}
526-
if (shouldSortShards(results.minAndMaxes) == false) {
470+
if (shouldSortShards(minAndMaxes) == false) {
527471
return shardsIts;
528472
}
529473
FieldSortBuilder fieldSort = FieldSortBuilder.getPrimaryFieldSortOrNull(request.source());
530-
return sortShards(shardsIts, results.minAndMaxes, fieldSort.order());
474+
return sortShards(shardsIts, minAndMaxes, fieldSort.order());
531475
}
532476

533477
private static List<SearchShardIterator> sortShards(List<SearchShardIterator> shardsIts, MinAndMax<?>[] minAndMaxes, SortOrder order) {

0 commit comments

Comments
 (0)