Skip to content

Commit 8f4a650

Browse files
Introduce batched query execution and data-node side reduce
Shortest version I could think of. Still WIP, have to make some test adjustments and polish rough edges, but it shouldn't get longer than this.
1 parent a9d6c12 commit 8f4a650

File tree

22 files changed

+1090
-92
lines changed

22 files changed

+1090
-92
lines changed

server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -574,11 +574,8 @@ public void testSearchQueryThenFetch() throws Exception {
574574
);
575575

576576
clearInterceptedActions();
577-
assertIndicesSubset(
578-
Arrays.asList(searchRequest.indices()),
579-
SearchTransportService.QUERY_ACTION_NAME,
580-
SearchTransportService.FETCH_ID_ACTION_NAME
581-
);
577+
assertIndicesSubset(Arrays.asList(searchRequest.indices()), true, SearchTransportService.QUERY_ACTION_NAME);
578+
assertIndicesSubset(Arrays.asList(searchRequest.indices()), SearchTransportService.FETCH_ID_ACTION_NAME);
582579
}
583580

584581
public void testSearchDfsQueryThenFetch() throws Exception {
@@ -631,10 +628,6 @@ private static void assertIndicesSubset(List<String> indices, String... actions)
631628
assertIndicesSubset(indices, false, actions);
632629
}
633630

634-
private static void assertIndicesSubsetOptionalRequests(List<String> indices, String... actions) {
635-
assertIndicesSubset(indices, true, actions);
636-
}
637-
638631
private static void assertIndicesSubset(List<String> indices, boolean optional, String... actions) {
639632
// indices returned by each bulk shard request need to be a subset of the original indices
640633
for (String action : actions) {

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ static TransportVersion def(int id) {
180180
public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00);
181181
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00);
182182
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00);
183+
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION = def(9_007_0_00);
183184

184185
/*
185186
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,34 +67,34 @@
6767
* distributed frequencies
6868
*/
6969
abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> extends SearchPhase {
70-
private static final float DEFAULT_INDEX_BOOST = 1.0f;
70+
protected static final float DEFAULT_INDEX_BOOST = 1.0f;
7171
private final Logger logger;
7272
private final NamedWriteableRegistry namedWriteableRegistry;
73-
private final SearchTransportService searchTransportService;
73+
protected final SearchTransportService searchTransportService;
7474
private final Executor executor;
7575
private final ActionListener<SearchResponse> listener;
76-
private final SearchRequest request;
76+
protected final SearchRequest request;
7777

7878
/**
7979
* Used by subclasses to resolve node ids to DiscoveryNodes.
8080
**/
8181
private final BiFunction<String, String, Transport.Connection> nodeIdToConnection;
82-
private final SearchTask task;
82+
protected final SearchTask task;
8383
protected final SearchPhaseResults<Result> results;
8484
private final long clusterStateVersion;
85-
private final TransportVersion minTransportVersion;
86-
private final Map<String, AliasFilter> aliasFilter;
87-
private final Map<String, Float> concreteIndexBoosts;
88-
private final SetOnce<AtomicArray<ShardSearchFailure>> shardFailures = new SetOnce<>();
85+
protected final TransportVersion minTransportVersion;
86+
protected final Map<String, AliasFilter> aliasFilter;
87+
protected final Map<String, Float> concreteIndexBoosts;
88+
protected final SetOnce<AtomicArray<ShardSearchFailure>> shardFailures = new SetOnce<>();
8989
private final Object shardFailuresMutex = new Object();
90-
private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
91-
private final AtomicInteger successfulOps = new AtomicInteger();
92-
private final SearchTimeProvider timeProvider;
90+
protected final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
91+
protected final AtomicInteger successfulOps = new AtomicInteger();
92+
protected final SearchTimeProvider timeProvider;
9393
private final SearchResponse.Clusters clusters;
9494

9595
protected final List<SearchShardIterator> toSkipShardsIts;
9696
protected final List<SearchShardIterator> shardsIts;
97-
private final SearchShardIterator[] shardIterators;
97+
protected final SearchShardIterator[] shardIterators;
9898
private final AtomicInteger outstandingShards;
9999
private final int maxConcurrentRequestsPerNode;
100100
private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
@@ -214,7 +214,7 @@ public final void start() {
214214
}
215215

216216
@Override
217-
protected final void run() {
217+
protected void run() {
218218
for (final SearchShardIterator iterator : toSkipShardsIts) {
219219
assert iterator.skip();
220220
skipShard(iterator);
@@ -290,7 +290,7 @@ private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt,
290290
public void innerOnResponse(Result result) {
291291
try {
292292
releasable.close();
293-
onShardResult(result, shardIt);
293+
onShardResult(result);
294294
} catch (Exception exc) {
295295
onShardFailure(shardIndex, shard, shardIt, exc);
296296
}
@@ -407,7 +407,7 @@ private void executePhase(SearchPhase phase) {
407407
}
408408
}
409409

410-
private ShardSearchFailure[] buildShardFailures() {
410+
protected ShardSearchFailure[] buildShardFailures() {
411411
AtomicArray<ShardSearchFailure> shardFailures = this.shardFailures.get();
412412
if (shardFailures == null) {
413413
return ShardSearchFailure.EMPTY_ARRAY;
@@ -420,7 +420,7 @@ private ShardSearchFailure[] buildShardFailures() {
420420
return failures;
421421
}
422422

423-
private void onShardFailure(final int shardIndex, SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) {
423+
protected void onShardFailure(final int shardIndex, SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) {
424424
// we always add the shard failure for a specific shard instance
425425
// we do make sure to clean it on a successful response from a shard
426426
onShardFailure(shardIndex, shard, e);
@@ -513,9 +513,8 @@ private static boolean isTaskCancelledException(Exception e) {
513513
/**
514514
* Executed once for every successful shard level request.
515515
* @param result the result returned from the shard
516-
* @param shardIt the shard iterator
517516
*/
518-
protected void onShardResult(Result result, SearchShardIterator shardIt) {
517+
protected void onShardResult(Result result) {
519518
assert result.getShardIndex() != -1 : "shard index is not set";
520519
assert result.getSearchShardTarget() != null : "search shard target must not be null";
521520
hasShardResponse.set(true);
@@ -705,7 +704,7 @@ void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connecti
705704
/**
706705
* Executed once all shard results have been received and processed
707706
* @see #onShardFailure(int, SearchShardTarget, Exception)
708-
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
707+
* @see #onShardResult(SearchPhaseResult)
709708
*/
710709
private void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
711710
executeNextPhase(getName(), this::getNextPhase);

server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java

Lines changed: 104 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@
1616
import org.elasticsearch.common.breaker.CircuitBreaker;
1717
import org.elasticsearch.common.breaker.CircuitBreakingException;
1818
import org.elasticsearch.common.io.stream.DelayableWriteable;
19+
import org.elasticsearch.common.io.stream.StreamInput;
20+
import org.elasticsearch.common.io.stream.StreamOutput;
21+
import org.elasticsearch.common.io.stream.Writeable;
22+
import org.elasticsearch.common.lucene.Lucene;
1923
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
2024
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
25+
import org.elasticsearch.core.Nullable;
26+
import org.elasticsearch.core.Tuple;
27+
import org.elasticsearch.index.shard.ShardId;
2128
import org.elasticsearch.search.SearchPhaseResult;
2229
import org.elasticsearch.search.SearchService;
2330
import org.elasticsearch.search.SearchShardTarget;
@@ -27,6 +34,7 @@
2734
import org.elasticsearch.search.query.QuerySearchResult;
2835
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
2936

37+
import java.io.IOException;
3038
import java.util.ArrayDeque;
3139
import java.util.ArrayList;
3240
import java.util.Collections;
@@ -66,7 +74,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
6674
private final Consumer<Exception> onPartialMergeFailure;
6775

6876
private final int batchReduceSize;
69-
private List<QuerySearchResult> buffer = new ArrayList<>();
77+
List<QuerySearchResult> buffer = new ArrayList<>();
7078
private List<SearchShard> emptyResults = new ArrayList<>();
7179
// the memory that is accounted in the circuit breaker for this consumer
7280
private volatile long circuitBreakerBytes;
@@ -76,9 +84,9 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
7684

7785
private final ArrayDeque<MergeTask> queue = new ArrayDeque<>();
7886
private final AtomicReference<MergeTask> runningTask = new AtomicReference<>();
79-
private final AtomicReference<Exception> failure = new AtomicReference<>();
87+
public final AtomicReference<Exception> failure = new AtomicReference<>();
8088

81-
private final TopDocsStats topDocsStats;
89+
public final TopDocsStats topDocsStats;
8290
private volatile MergeResult mergeResult;
8391
private volatile boolean hasPartialReduce;
8492
private volatile int numReducePhases;
@@ -149,6 +157,33 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {
149157
consume(querySearchResult, next);
150158
}
151159

160+
private final List<Tuple<TopDocsStats, MergeResult>> batchedResults = new ArrayList<>();
161+
162+
/**
 * Takes ownership of whatever this consumer has reduced or buffered so far and returns it as a single
 * {@link MergeResult}.
 * <p>
 * The current {@code mergeResult} field is read and cleared. If the buffer is non-null and non-empty it is detached
 * from this consumer, sorted with {@code RESULT_COMPARATOR} and passed to {@code partialReduce} together with the
 * previously read merge result; {@code emptyResults} is cleared afterwards.
 * <p>
 * NOTE(review): only the read of {@code buffer} happens under the monitor; the writes to {@code buffer},
 * {@code emptyResults} and {@code numReducePhases} do not. Presumably this is only called once no more results are
 * being consumed (see the assertion on {@code runningTask}) -- confirm against the caller.
 *
 * @return the merge result after folding in the buffered results; {@code null} when there was no previous merge
 *         result and nothing was buffered
 */
public MergeResult consumePartialResult() {
163+
var mergeResult = this.mergeResult;
164+
this.mergeResult = null;
165+
// no partial merge task may be in flight at this point
assert runningTask.get() == null;
166+
final List<QuerySearchResult> buffer;
167+
synchronized (this) {
168+
buffer = this.buffer;
169+
}
170+
if (buffer != null && buffer.isEmpty() == false) {
171+
// detach the buffer from the consumer before reducing it
this.buffer = null;
172+
buffer.sort(RESULT_COMPARATOR);
173+
mergeResult = partialReduce(buffer, emptyResults, topDocsStats, mergeResult, numReducePhases++);
174+
emptyResults = null;
175+
}
176+
return mergeResult;
177+
}
178+
179+
/**
 * Records an already-reduced partial result, paired with its top-docs statistics, in {@code batchedResults}.
 * <p>
 * The pair is only stored when the merge result reports at least one processed shard; otherwise both arguments are
 * dropped. The list is mutated while holding the monitor of {@code batchedResults}.
 * <p>
 * NOTE(review): {@code mergeResult} is dereferenced without a null check -- confirm callers never pass {@code null}.
 *
 * @param topDocsStats the statistics that belong to {@code mergeResult}
 * @param mergeResult  the partial merge result to keep
 */
public void addPartialResult(TopDocsStats topDocsStats, MergeResult mergeResult) {
180+
if (mergeResult.processedShards.isEmpty() == false) {
181+
synchronized (batchedResults) {
182+
batchedResults.add(new Tuple<>(topDocsStats, mergeResult));
183+
}
184+
}
185+
}
186+
152187
@Override
153188
public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
154189
if (hasPendingMerges()) {
@@ -171,19 +206,26 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
171206
buffer.sort(RESULT_COMPARATOR);
172207
final TopDocsStats topDocsStats = this.topDocsStats;
173208
var mergeResult = this.mergeResult;
174-
this.mergeResult = null;
175-
final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1);
209+
final List<Tuple<TopDocsStats, MergeResult>> batchedResults;
210+
synchronized (this.batchedResults) {
211+
batchedResults = this.batchedResults;
212+
}
213+
final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1) + batchedResults.size();
176214
final List<TopDocs> topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null;
177215
final List<DelayableWriteable<InternalAggregations>> aggsList = hasAggs ? new ArrayList<>(resultSize) : null;
178216
if (mergeResult != null) {
179-
if (topDocsList != null) {
180-
topDocsList.add(mergeResult.reducedTopDocs);
181-
}
182-
if (aggsList != null) {
183-
aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
184-
}
217+
this.mergeResult = null;
218+
consumePartialMergeResult(mergeResult, topDocsList, aggsList);
219+
}
220+
for (int i = 0; i < batchedResults.size(); i++) {
221+
Tuple<TopDocsStats, MergeResult> batchedResult = batchedResults.set(i, null);
222+
consumePartialMergeResult(batchedResult.v2(), topDocsList, aggsList);
223+
topDocsStats.add(batchedResult.v1());
185224
}
186225
for (QuerySearchResult result : buffer) {
226+
if (result.isReduced()) {
227+
continue;
228+
}
187229
topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
188230
if (topDocsList != null) {
189231
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
@@ -236,6 +278,19 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
236278

237279
}
238280

281+
/**
 * Appends the contents of a partial merge result to the lists that feed the final reduction.
 * <p>
 * Each list is optional: when {@code topDocsList} is {@code null} the reduced top docs are skipped, and when
 * {@code aggsList} is {@code null} the reduced aggregations are skipped.
 * <p>
 * NOTE(review): {@code reducedAggs} is declared nullable on {@code MergeResult}; presumably it is non-null whenever
 * {@code aggsList} is non-null -- confirm.
 *
 * @param partialResult the partial merge result to read from
 * @param topDocsList   receives {@code partialResult.reducedTopDocs}, may be {@code null}
 * @param aggsList      receives {@code partialResult.reducedAggs} wrapped via {@code DelayableWriteable.referencing},
 *                      may be {@code null}
 */
private static void consumePartialMergeResult(
282+
MergeResult partialResult,
283+
List<TopDocs> topDocsList,
284+
List<DelayableWriteable<InternalAggregations>> aggsList
285+
) {
286+
if (topDocsList != null) {
287+
topDocsList.add(partialResult.reducedTopDocs);
288+
}
289+
if (aggsList != null) {
290+
aggsList.add(DelayableWriteable.referencing(partialResult.reducedAggs));
291+
}
292+
}
293+
239294
private static final Comparator<QuerySearchResult> RESULT_COMPARATOR = Comparator.comparingInt(QuerySearchResult::getShardIndex);
240295

241296
private MergeResult partialReduce(
@@ -284,12 +339,15 @@ private MergeResult partialReduce(
284339
}
285340
}
286341
// we have to merge here in the same way we collect on a shard
287-
newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0);
342+
newTopDocs = topDocsList == null ? Lucene.EMPTY_TOP_DOCS : mergeTopDocs(topDocsList, topNSize, 0);
288343
newAggs = aggsList == null
289344
? null
290345
: InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
291346
} finally {
292347
releaseAggs(toConsume);
348+
for (QuerySearchResult querySearchResult : toConsume) {
349+
querySearchResult.setReduced();
350+
}
293351
}
294352
if (lastMerge != null) {
295353
processedShards.addAll(lastMerge.processedShards);
@@ -306,7 +364,7 @@ public int getNumReducePhases() {
306364
return numReducePhases;
307365
}
308366

309-
private boolean hasFailure() {
367+
/**
 * @return {@code true} if the {@code failure} reference currently holds an exception, {@code false} otherwise
 */
public boolean hasFailure() {
310368
return failure.get() != null;
311369
}
312370

@@ -351,8 +409,15 @@ private void consume(QuerySearchResult result, Runnable next) {
351409
if (hasFailure()) {
352410
result.consumeAll();
353411
next.run();
354-
} else if (result.isNull()) {
355-
result.consumeAll();
412+
} else if (result.isNull() || result.isReduced()) {
413+
if (result.isReduced()) {
414+
if (result.hasConsumedTopDocs() == false) {
415+
result.consumeTopDocs();
416+
}
417+
result.releaseAggs();
418+
} else {
419+
result.consumeAll();
420+
}
356421
SearchShardTarget target = result.getSearchShardTarget();
357422
SearchShard searchShard = new SearchShard(target.getClusterAlias(), target.getShardId());
358423
synchronized (this) {
@@ -522,12 +587,33 @@ private static void releaseAggs(List<QuerySearchResult> toConsume) {
522587
}
523588
}
524589

525-
private record MergeResult(
590+
public record MergeResult(
526591
List<SearchShard> processedShards,
527592
TopDocs reducedTopDocs,
528-
InternalAggregations reducedAggs,
593+
@Nullable InternalAggregations reducedAggs,
529594
long estimatedSize
530-
) {}
595+
) implements Writeable {
596+
597+
/**
 * Reads a {@link MergeResult} from the given stream.
 * <p>
 * Fields are read in this order: the processed shards (each as an optional cluster alias string followed by a
 * {@link ShardId}, collected into an immutable list), the reduced top docs, the optional reduced aggregations and
 * the estimated size as a variable-length long.
 * <p>
 * NOTE(review): the top docs are read with {@code Lucene.readTopDocsOnly} while {@code writeTo} emits them with
 * {@code Lucene.writeTopDocsIncludingShardIndex} -- confirm that these two are a matching wire-format pair.
 *
 * @param in the stream to read from
 * @return the deserialized merge result
 * @throws IOException declared for the underlying stream reads
 */
static MergeResult readFrom(StreamInput in) throws IOException {
598+
return new MergeResult(
599+
in.readCollectionAsImmutableList(i -> new SearchShard(i.readOptionalString(), new ShardId(i))),
600+
Lucene.readTopDocsOnly(in),
601+
in.readOptionalWriteable(InternalAggregations::readFrom),
602+
in.readVLong()
603+
);
604+
}
605+
606+
/**
 * Writes this merge result to the given stream.
 * <p>
 * Fields are written in this order: the processed shards (each as an optional cluster alias string followed by its
 * shard id), the reduced top docs, the optional reduced aggregations and the estimated size as a variable-length
 * long.
 * <p>
 * NOTE(review): the top docs are written with {@code Lucene.writeTopDocsIncludingShardIndex} while {@code readFrom}
 * consumes them with {@code Lucene.readTopDocsOnly} -- confirm that these two are a matching wire-format pair.
 *
 * @param out the stream to write to
 * @throws IOException declared for the underlying stream writes
 */
@Override
607+
public void writeTo(StreamOutput out) throws IOException {
608+
out.writeCollection(processedShards, (o, s) -> {
609+
o.writeOptionalString(s.clusterAlias());
610+
s.shardId().writeTo(o);
611+
});
612+
Lucene.writeTopDocsIncludingShardIndex(out, reducedTopDocs);
613+
out.writeOptionalWriteable(reducedAggs);
614+
out.writeVLong(estimatedSize);
615+
}
616+
}
531617

532618
private static class MergeTask {
533619
private final List<SearchShard> emptyResults;

0 commit comments

Comments
 (0)