diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index e42f8127c5e97..68c8d0a5c5bf9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -22,11 +22,11 @@ import org.elasticsearch.search.CanMatchShardResponse; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.MinAndMax; -import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transport; @@ -39,6 +39,7 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiFunction; @@ -75,7 +76,10 @@ final class CanMatchPreFilterSearchPhase { private final FixedBitSet possibleMatches; private final MinAndMax[] minAndMaxes; private int numPossibleMatches; - private final CoordinatorRewriteContextProvider coordinatorRewriteContextProvider; + // True if the initiating action to this can_match run is doing batched query phase execution. + // If batched query phase execution is in use, then there is no need to physically send can_match requests to other nodes + // and only the coordinating coordinator can_match logic will run. + private final boolean batchQueryPhase; private CanMatchPreFilterSearchPhase( Logger logger, @@ -89,7 +93,7 @@ private CanMatchPreFilterSearchPhase( TransportSearchAction.SearchTimeProvider timeProvider, SearchTask task, boolean requireAtLeastOneMatch, - CoordinatorRewriteContextProvider coordinatorRewriteContextProvider, + boolean batchQueryPhase, ActionListener> listener ) { this.logger = logger; @@ -103,7 +107,6 @@ private CanMatchPreFilterSearchPhase( this.aliasFilter = aliasFilter; this.task = task; this.requireAtLeastOneMatch = requireAtLeastOneMatch; - this.coordinatorRewriteContextProvider = coordinatorRewriteContextProvider; this.executor = executor; final int size = shardsIts.size(); possibleMatches = new FixedBitSet(size); @@ -122,6 +125,7 @@ private CanMatchPreFilterSearchPhase( shardItIndexMap.put(naturalOrder[j], j); } this.shardItIndexMap = shardItIndexMap; + this.batchQueryPhase = batchQueryPhase; } public static SubscribableListener> execute( @@ -130,17 +134,19 @@ public static SubscribableListener> execute( BiFunction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, - Executor executor, SearchRequest request, List shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, SearchTask task, boolean requireAtLeastOneMatch, - CoordinatorRewriteContextProvider coordinatorRewriteContextProvider + boolean batchQueryPhase, + SearchService searchService ) { + if (shardsIts.isEmpty()) { return SubscribableListener.newSucceeded(List.of()); } + ExecutorService executor = searchTransportService.transportService().getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION); final SubscribableListener> listener = new SubscribableListener<>(); // Note that the search is failed when this task is rejected by the executor executor.execute(new AbstractRunnable() { @@ -167,9 +173,9 @@ protected void doRun() { timeProvider, task, requireAtLeastOneMatch, - coordinatorRewriteContextProvider, + batchQueryPhase && searchService.batchQueryPhase(), listener - ).runCoordinatorRewritePhase(); + ).runCoordinatorRewritePhase(searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis)); } }); return listener; @@ -181,7 +187,7 @@ private static boolean assertSearchCoordinationThread() { // tries to pre-filter shards based on information that's available to the coordinator // without having to reach out to the actual shards - private void runCoordinatorRewritePhase() { + private void runCoordinatorRewritePhase(CoordinatorRewriteContextProvider coordinatorRewriteContextProvider) { // TODO: the index filter (i.e, `_index:patten`) should be prefiltered on the coordinator assert assertSearchCoordinationThread(); final List matchedShardLevelRequests = new ArrayList<>(); @@ -304,36 +310,17 @@ protected void doRun() { var sendingTarget = entry.getKey(); try { - searchTransportService.sendCanMatch( - nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId), - canMatchNodeRequest, - task, - new ActionListener<>() { - @Override - public void onResponse(CanMatchNodeResponse canMatchNodeResponse) { - assert canMatchNodeResponse.getResponses().size() == canMatchNodeRequest.getShardLevelRequests().size(); - for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) { - CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i); - if (response.getResponse() != null) { - CanMatchShardResponse shardResponse = response.getResponse(); - shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex()); - onOperation(shardResponse.getShardIndex(), shardResponse); - } else { - Exception failure = response.getException(); - assert failure != null; - onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure); - } - } - } - - @Override - public void onFailure(Exception e) { - for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), e); - } - } + var connection = nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId); + if (batchQueryPhase && SearchQueryThenFetchAsyncAction.connectionSupportsBatchedExecution(connection)) { + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + final int idx = shard.getShardRequestIndex(); + CanMatchShardResponse shardResponse = new CanMatchShardResponse(true, null); + shardResponse.setShardIndex(idx); + onOperation(idx, shardResponse); } - ); + } else { + bwcSendCanMatchRequest(connection, canMatchNodeRequest, shardLevelRequests); + } } catch (Exception e) { for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { onOperationFailed(shard.getShardRequestIndex(), e); @@ -342,6 +329,38 @@ public void onFailure(Exception e) { } } + private void bwcSendCanMatchRequest( + Transport.Connection connection, + CanMatchNodeRequest canMatchNodeRequest, + List shardLevelRequests + ) { + searchTransportService.sendCanMatch(connection, canMatchNodeRequest, task, new ActionListener<>() { + @Override + public void onResponse(CanMatchNodeResponse canMatchNodeResponse) { + assert canMatchNodeResponse.getResponses().size() == shardLevelRequests.size(); + for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) { + CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i); + if (response.getResponse() != null) { + CanMatchShardResponse shardResponse = response.getResponse(); + shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex()); + onOperation(shardResponse.getShardIndex(), shardResponse); + } else { + Exception failure = response.getException(); + assert failure != null; + onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure); + } + } + } + + @Override + public void onFailure(Exception e) { + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + onOperationFailed(shard.getShardRequestIndex(), e); + } + } + }); + } + private void onOperation(int idx, CanMatchShardResponse response) { failedResponses.set(idx, null); consumeResult(response); @@ -461,17 +480,23 @@ private synchronized List getIterator(List list = new ArrayList<>(indexTranslation.length); + for (int in : indexTranslation) { + list.add(shardsIts.get(in)); + } + return list; } - private static List sortShards(List shardsIts, MinAndMax[] minAndMaxes, SortOrder order) { + public static > int[] sortShards(List shardsIts, MinAndMax[] minAndMaxes, SearchSourceBuilder source) { int bound = shardsIts.size(); List toSort = new ArrayList<>(bound); for (int i = 0; i < bound; i++) { toSort.add(i); } - Comparator> keyComparator = forciblyCast(MinAndMax.getComparator(order)); + Comparator> keyComparator = forciblyCast( + MinAndMax.getComparator(FieldSortBuilder.getPrimaryFieldSortOrNull(source).order()) + ); toSort.sort((idx1, idx2) -> { int res = keyComparator.compare(minAndMaxes[idx1], minAndMaxes[idx2]); if (res != 0) { @@ -479,11 +504,11 @@ private static List sortShards(List sh } return shardsIts.get(idx1).compareTo(shardsIts.get(idx2)); }); - List list = new ArrayList<>(bound); - for (Integer integer : toSort) { - list.add(shardsIts.get(integer)); + int[] result = new int[bound]; + for (int i = 0; i < bound; i++) { + result[i] = toSort.get(i); } - return list; + return result; } private static boolean shouldSortShards(MinAndMax[] minAndMaxes) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 39e1c30f658d8..563bd26cdb0bf 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -29,9 +29,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ListenableFuture; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.SimpleRefCounted; import org.elasticsearch.core.TimeValue; @@ -46,6 +48,7 @@ import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.sort.MinAndMax; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; @@ -76,8 +79,11 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; +import java.util.function.IntUnaryOperator; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; +import static org.elasticsearch.search.sort.FieldSortBuilder.NAME; +import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { @@ -91,6 +97,7 @@ public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction, Writeable { static ShardToQuery readFrom(StreamInput in) throws IOException { @@ -365,6 +374,11 @@ public void writeTo(StreamOutput out) throws IOException { shardId.writeTo(out); out.writeOptionalWriteable(contextId); } + + @Override + public int compareTo(ShardToQuery o) { + return shardId.compareTo(o.shardId); + } } /** @@ -386,11 +400,13 @@ private static ShardSearchRequest tryRewriteWithUpdatedSortValue( // disable tracking total hits if we already reached the required estimation. if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_ACCURATE && bottomSortCollector.getTotalHits() > trackTotalHitsUpTo) { request.source(request.source().shallowCopy().trackTotalHits(false)); + request.setRunCanMatchInQueryPhase(true); } // set the current best bottom field doc if (bottomSortCollector.getBottomSortValues() != null) { request.setBottomSortValues(bottomSortCollector.getBottomSortValues()); + request.setRunCanMatchInQueryPhase(true); } return request; } @@ -412,7 +428,8 @@ protected void doRun(Map shardIndexMap) { } AbstractSearchAsyncAction.doCheckNoMissingShards(getName(), request, shardsIts); final Map perNodeQueries = new HashMap<>(); - final String localNodeId = searchTransportService.transportService().getLocalNode().getId(); + final var transportService = searchTransportService.transportService(); + final String localNodeId = transportService.getLocalNode().getId(); final int numberOfShardsTotal = shardsIts.size(); for (int i = 0; i < numberOfShardsTotal; i++) { final SearchShardIterator shardRoutings = shardsIts.get(i); @@ -425,30 +442,65 @@ protected void doRun(Map shardIndexMap) { } else { final String nodeId = routing.getNodeId(); // local requests don't need batching as there's no network latency - if (localNodeId.equals(nodeId)) { - performPhaseOnShard(shardIndex, shardRoutings, routing); - } else { - var perNodeRequest = perNodeQueries.computeIfAbsent( - new CanMatchPreFilterSearchPhase.SendingTarget(routing.getClusterAlias(), nodeId), - t -> new NodeQueryRequest(request, numberOfShardsTotal, timeProvider.absoluteStartMillis(), t.clusterAlias()) - ); - final String indexUUID = routing.getShardId().getIndex().getUUID(); - perNodeRequest.shards.add( - new ShardToQuery( - concreteIndexBoosts.getOrDefault(indexUUID, DEFAULT_INDEX_BOOST), - getOriginalIndices(shardIndex).indices(), - shardIndex, - routing.getShardId(), - shardRoutings.getSearchContextId() - ) - ); - var filterForAlias = aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY); - if (filterForAlias != AliasFilter.EMPTY) { - perNodeRequest.aliasFilters.putIfAbsent(indexUUID, filterForAlias); - } + var perNodeRequest = perNodeQueries.computeIfAbsent( + new CanMatchPreFilterSearchPhase.SendingTarget(routing.getClusterAlias(), nodeId), + t -> new NodeQueryRequest(request, numberOfShardsTotal, timeProvider.absoluteStartMillis(), t.clusterAlias()) + ); + final String indexUUID = routing.getShardId().getIndex().getUUID(); + perNodeRequest.shards.add( + new ShardToQuery( + concreteIndexBoosts.getOrDefault(indexUUID, DEFAULT_INDEX_BOOST), + getOriginalIndices(shardIndex).indices(), + shardIndex, + routing.getShardId(), + shardRoutings.getSearchContextId() + ) + ); + var filterForAlias = aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY); + if (filterForAlias != AliasFilter.EMPTY) { + perNodeRequest.aliasFilters.putIfAbsent(indexUUID, filterForAlias); } } } + final var localTarget = new CanMatchPreFilterSearchPhase.SendingTarget(request.getLocalClusterAlias(), localNodeId); + var localNodeRequest = perNodeQueries.remove(localTarget); + if (localNodeRequest != null) { + transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION).execute(new AbstractRunnable() { + @Override + protected void doRun() { + var shards = localNodeRequest.shards; + if (shards.size() > 1 && hasPrimaryFieldSort(request.source())) { + @SuppressWarnings("rawtypes") + final MinAndMax[] minAndMax = new MinAndMax[shards.size()]; + for (int i = 0; i < minAndMax.length; i++) { + // TODO: refactor to avoid building the search request twice, here and then when actually executing the query + minAndMax[i] = searchService.canMatch(buildShardSearchRequestForLocal(localNodeRequest, shards.get(i))) + .estimatedMinAndMax(); + } + + try { + final int[] indexes = CanMatchPreFilterSearchPhase.sortShards(shards, minAndMax, request.source()); + final ShardToQuery[] orig = shards.toArray(new ShardToQuery[0]); + for (int i = 0; i < indexes.length; i++) { + shards.set(i, orig[indexes[i]]); + } + } catch (Exception e) { + // ignored, field type conflicts will be dealt with in upstream logic + // TODO: we should fail the query here, we're already seeing a field type conflict on the sort field, + // no need to actually execute the queries and go through a lot of work before we inevitably have to + // fail the search + + } + } + executeWithoutBatching(localTarget, localNodeRequest); + } + + @Override + public void onFailure(Exception e) { + SearchQueryThenFetchAsyncAction.this.onPhaseFailure(NAME, "", e); + } + }); + } perNodeQueries.forEach((routing, request) -> { if (request.shards.size() == 1) { executeAsSingleRequest(routing, request.shards.getFirst()); @@ -462,13 +514,16 @@ protected void doRun(Map shardIndexMap) { return; } // must check both node and transport versions to correctly deal with BwC on proxy connections - if (connection.getTransportVersion().before(TransportVersions.BATCHED_QUERY_PHASE_VERSION) - || connection.getNode().getVersionInformation().nodeVersion().before(Version.V_9_1_0)) { + if (connectionSupportsBatchedExecution(connection) == false) { executeWithoutBatching(routing, request); return; } - searchTransportService.transportService() - .sendChildRequest(connection, NODE_SEARCH_ACTION_NAME, request, task, new TransportResponseHandler() { + transportService.sendChildRequest( + connection, + NODE_SEARCH_ACTION_NAME, + request, + task, + new TransportResponseHandler() { @Override public NodeQueryResponse read(StreamInput in) throws IOException { return new NodeQueryResponse(in); @@ -521,10 +576,36 @@ public void handleException(TransportException e) { onPhaseFailure(getName(), "", cause); } } - }); + } + ); }); } + private static ShardSearchRequest buildShardSearchRequestForLocal(NodeQueryRequest nodeQueryRequest, ShardToQuery shardToQuery) { + var shardId = shardToQuery.shardId; + var searchRequest = nodeQueryRequest.searchRequest; + var pitBuilder = searchRequest.pointInTimeBuilder(); + return buildShardSearchRequest( + shardId, + nodeQueryRequest.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, searchRequest.indicesOptions()), + nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + nodeQueryRequest.totalShards, + nodeQueryRequest.absoluteStartMillis, + false + ); + } + + public static boolean connectionSupportsBatchedExecution(Transport.Connection connection) { + return connection.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION) + && connection.getNode().getVersionInformation().nodeVersion().onOrAfter(Version.V_9_1_0); + } + private void executeWithoutBatching(CanMatchPreFilterSearchPhase.SendingTarget targetNode, NodeQueryRequest request) { for (ShardToQuery shard : request.shards) { executeAsSingleRequest(targetNode, shard); @@ -562,15 +643,39 @@ static void registerNodeSearchAction( final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax(); transportService.registerRequestHandler( NODE_SEARCH_ACTION_NAME, - EsExecutors.DIRECT_EXECUTOR_SERVICE, + threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), NodeQueryRequest::new, (request, channel, task) -> { - final CancellableTask cancellableTask = (CancellableTask) task; + final SearchRequest searchRequest = request.searchRequest; + ShardSearchRequest[] shardSearchRequests = null; + IntUnaryOperator shards = IntUnaryOperator.identity(); final int shardCount = request.shards.size(); - int workers = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); + if (shardCount > 1 && hasPrimaryFieldSort(searchRequest.source())) { + try { + shardSearchRequests = new ShardSearchRequest[shardCount]; + @SuppressWarnings("rawtypes") + final MinAndMax[] minAndMax = new MinAndMax[shardCount]; + for (int i = 0; i < minAndMax.length; i++) { + ShardSearchRequest r = buildShardSearchRequestForLocal(request, request.shards.get(i)); + shardSearchRequests[i] = r; + var canMatch = searchService.canMatch(r); + if (canMatch.canMatch()) { + r.setRunCanMatchInQueryPhase(false); + minAndMax[i] = canMatch.estimatedMinAndMax(); + } + } + int[] indexes = CanMatchPreFilterSearchPhase.sortShards(request.shards, minAndMax, searchRequest.source()); + shards = pos -> indexes[pos]; + } catch (Exception e) { + // TODO: ignored for now but we'll be guaranteed to fail the query phase at this point, fix things to fail here + // already + } + } + final CancellableTask cancellableTask = (CancellableTask) task; + int workers = Math.min(searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); final var state = new QueryPerNodeState( new QueryPhaseResultConsumer( - request.searchRequest, + searchRequest, dependencies.executor, searchService.getCircuitBreaker(), searchPhaseController, @@ -580,9 +685,11 @@ static void registerNodeSearchAction( e -> logger.error("failed to merge on data node", e) ), request, + shards, cancellableTask, channel, - dependencies + dependencies, + shardSearchRequests ); // TODO: log activating or otherwise limiting parallelism might be helpful here for (int i = 0; i < workers; i++) { @@ -646,35 +753,39 @@ private static ShardSearchRequest buildShardSearchRequest( private static void executeShardTasks(QueryPerNodeState state) { int idx; - final int totalShardCount = state.searchRequest.shards.size(); + final NodeQueryRequest nodeQueryRequest = state.searchRequest; + var shards = nodeQueryRequest.shards; + final int totalShardCount = shards.size(); while ((idx = state.currentShardIndex.getAndIncrement()) < totalShardCount) { final int dataNodeLocalIdx = idx; final ListenableFuture doneFuture = new ListenableFuture<>(); try { - final NodeQueryRequest nodeQueryRequest = state.searchRequest; final SearchRequest searchRequest = nodeQueryRequest.searchRequest; var pitBuilder = searchRequest.pointInTimeBuilder(); - var shardToQuery = nodeQueryRequest.shards.get(dataNodeLocalIdx); + int translatedIndex = state.shardsToQuery.applyAsInt(dataNodeLocalIdx); + var shardToQuery = shards.get(translatedIndex); final var shardId = shardToQuery.shardId; + ShardSearchRequest r = state.shardSearchRequests == null ? null : state.shardSearchRequests[translatedIndex]; + if (r == null) { + r = buildShardSearchRequest( + shardId, + nodeQueryRequest.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, nodeQueryRequest.indicesOptions()), + nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + nodeQueryRequest.totalShards, + nodeQueryRequest.absoluteStartMillis, + state.hasResponse.getAcquire() + ); + } else { + state.shardSearchRequests[translatedIndex] = null; + } state.dependencies.searchService.executeQueryPhase( - tryRewriteWithUpdatedSortValue( - state.bottomSortCollector, - state.trackTotalHitsUpTo, - buildShardSearchRequest( - shardId, - nodeQueryRequest.localClusterAlias, - shardToQuery.shardIndex, - shardToQuery.contextId, - new OriginalIndices(shardToQuery.originalIndices, nodeQueryRequest.indicesOptions()), - nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), - pitBuilder == null ? null : pitBuilder.getKeepAlive(), - shardToQuery.boost, - searchRequest, - nodeQueryRequest.totalShards, - nodeQueryRequest.absoluteStartMillis, - state.hasResponse.getAcquire() - ) - ), + tryRewriteWithUpdatedSortValue(state.bottomSortCollector, state.trackTotalHitsUpTo, r), state.task, new SearchActionListener<>( new SearchShardTarget(null, shardToQuery.shardId, nodeQueryRequest.localClusterAlias), @@ -733,6 +844,7 @@ private static final class QueryPerNodeState { private final AtomicInteger currentShardIndex = new AtomicInteger(); private final QueryPhaseResultConsumer queryPhaseResultConsumer; private final NodeQueryRequest searchRequest; + private final IntUnaryOperator shardsToQuery; private final CancellableTask task; private final ConcurrentHashMap failures = new ConcurrentHashMap<>(); private final Dependencies dependencies; @@ -740,24 +852,29 @@ private static final class QueryPerNodeState { private final int trackTotalHitsUpTo; private final int topDocsSize; private final CountDown countDown; + private final @Nullable ShardSearchRequest[] shardSearchRequests; private final TransportChannel channel; private volatile BottomSortValuesCollector bottomSortCollector; private QueryPerNodeState( QueryPhaseResultConsumer queryPhaseResultConsumer, NodeQueryRequest searchRequest, + IntUnaryOperator shardsToQuery, CancellableTask task, TransportChannel channel, - Dependencies dependencies + Dependencies dependencies, + @Nullable ShardSearchRequest[] shardSearchRequests ) { this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchRequest = searchRequest; + this.shardsToQuery = shardsToQuery; this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo(); this.topDocsSize = getTopDocsSize(searchRequest.searchRequest); this.task = task; this.countDown = new CountDown(queryPhaseResultConsumer.getNumShards()); this.channel = channel; this.dependencies = dependencies; + this.shardSearchRequests = shardSearchRequests; } void onShardDone() { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index ac23731c38b84..54e355ffe5069 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -167,13 +167,13 @@ public void runNewSearchPhase( connectionLookup, aliasFilter, concreteIndexBoosts, - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardIterators, timeProvider, task, false, - searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis) + false, + searchService ) .addListener( listener.delegateFailureAndWrap( diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 21eeaedb7ea54..9c83d4337e3aa 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -1482,13 +1482,13 @@ public void runNewSearchPhase( connectionLookup, aliasFilter, concreteIndexBoosts, - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardIterators, timeProvider, task, requireAtLeastOneMatch, - searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis) + false, + searchService ) .addListener( listener.delegateFailureAndWrap( @@ -1568,7 +1568,7 @@ public void runNewSearchPhase( task, clusters, client, - searchService.batchQueryPhase() + searchService ); } success = true; diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index d12847ec8bf7f..207ce7843526a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -156,17 +156,7 @@ public void searchShards(Task task, SearchShardsRequest searchShardsRequest, Act CanMatchPreFilterSearchPhase.execute(logger, searchTransportService, (clusterAlias, node) -> { assert Objects.equals(clusterAlias, searchShardsRequest.clusterAlias()); return transportService.getConnection(project.cluster().nodes().get(node)); - }, - aliasFilters, - Map.of(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), - searchRequest, - shardIts, - timeProvider, - (SearchTask) task, - false, - searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis) - ) + }, aliasFilters, Map.of(), searchRequest, shardIts, timeProvider, (SearchTask) task, false, false, searchService) .addListener( delegate.map( its -> new SearchShardsResponse(toGroups(its), project.cluster().nodes().getAllNodes(), aliasFilters) diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 2a814a1a36489..2ddf00e7fefe8 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -689,7 +689,7 @@ public void executeQueryPhase(ShardSearchRequest request, CancellableTask task, threadPool ).delegateFailure((l, orig) -> { // check if we can shortcut the query phase entirely. - if (orig.canReturnNullResponseIfMatchNoDocs()) { + if (orig.canReturnNullResponseIfMatchNoDocs() && orig.runCanMatchInQueryPhase()) { assert orig.scroll() == null; ShardSearchRequest clone = new ShardSearchRequest(orig); CanMatchContext canMatchContext = new CanMatchContext( diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java index 10d2fb0e23b3b..9bce3d1163d22 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java @@ -101,6 +101,8 @@ public class ShardSearchRequest extends AbstractTransportRequest implements Indi */ private final boolean forceSyntheticSource; + private transient boolean runCanMatchInQueryPhase = true; + public ShardSearchRequest( OriginalIndices originalIndices, SearchRequest searchRequest, @@ -349,6 +351,14 @@ public void writeTo(StreamOutput out) throws IOException { OriginalIndices.writeOriginalIndices(originalIndices, out); } + public void setRunCanMatchInQueryPhase(boolean runCanMatchInQueryPhase) { + this.runCanMatchInQueryPhase = runCanMatchInQueryPhase; + } + + public boolean runCanMatchInQueryPhase() { + return runCanMatchInQueryPhase; + } + protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException { shardId.writeTo(out); out.writeByte(searchType.id()); diff --git a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 1c3a6cd47a3b7..f8651c2b76d4c 100644 --- a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -44,6 +44,7 @@ import org.elasticsearch.index.shard.ShardLongFieldRange; import org.elasticsearch.indices.DateFieldRangeInfo; import org.elasticsearch.search.CanMatchShardResponse; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -55,8 +56,8 @@ import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParserConfiguration; import java.util.ArrayList; @@ -81,7 +82,9 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class CanMatchPreFilterSearchPhaseTests extends ESTestCase { @@ -134,6 +137,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -155,13 +163,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, true, - EMPTY_CONTEXT_PROVIDER + false, + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -185,6 +193,12 @@ public void sendCanMatch( } } + private SearchService mockSearchService() { + var searchService = mock(SearchService.class); + when(searchService.getCoordinatorRewriteContextProvider(any())).thenReturn(EMPTY_CONTEXT_PROVIDER); + return searchService; + } + public void testFilterWithFailure() throws InterruptedException { final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( 0, @@ -228,6 +242,11 @@ public void sendCanMatch( } }).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -250,13 +269,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, true, - EMPTY_CONTEXT_PROVIDER + false, + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -319,6 +338,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -341,13 +365,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, true, - EMPTY_CONTEXT_PROVIDER + false, + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -418,6 +442,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -440,13 +469,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, shardsIter.size() > shardToSkip.size(), - EMPTY_CONTEXT_PROVIDER + false, + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -1397,6 +1426,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( @@ -1405,6 +1439,8 @@ public void sendCanMatch( System::nanoTime ); + var searchService = mock(SearchService.class); + when(searchService.getCoordinatorRewriteContextProvider(any())).thenReturn(contextProvider); return new Tuple<>( CanMatchPreFilterSearchPhase.execute( logger, @@ -1412,18 +1448,24 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), aliasFilters, Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardIters, timeProvider, null, true, - contextProvider + false, + searchService ), requests ); } + private TransportService mockTransportService() { + var transportService = mock(TransportService.class); + when(transportService.getThreadPool()).thenReturn(threadPool); + return transportService; + } + static class StaticCoordinatorRewriteContextProviderBuilder { private ClusterState clusterState = ClusterState.EMPTY_STATE; private final Map fields = new HashMap<>(); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index d7348833c757a..a239680a8e6c4 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.lucene.grouping.TopFieldGroups; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; @@ -207,7 +208,7 @@ public void sendExecuteQuery( task, SearchResponse.Clusters.EMPTY, null, - false + mock(SearchService.class) ) { @Override protected SearchPhase getNextPhase() {