Merged
Changes from 10 commits
150 changes: 120 additions & 30 deletions server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java
@@ -8,8 +8,15 @@
*/
package org.elasticsearch.action.search;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
@@ -27,9 +34,11 @@
import org.elasticsearch.transport.Transport;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
import java.util.Map;

/**
* This search phase fans out to every shard to execute a distributed search with pre-collected distributed frequencies for all
@@ -38,56 +47,50 @@
* operation.
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
final class DfsQueryPhase extends SearchPhase {
class DfsQueryPhase extends SearchPhase {

public static final String NAME = "dfs_query";

private final SearchPhaseResults<SearchPhaseResult> queryResult;
private final List<DfsSearchResult> searchResults;
private final AggregatedDfs dfs;
private final List<DfsKnnResults> knnResults;
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
private final Client client;
private final AbstractSearchAsyncAction<?> context;
private final SearchTransportService searchTransportService;
private final SearchProgressListener progressListener;

DfsQueryPhase(
List<DfsSearchResult> searchResults,
AggregatedDfs dfs,
List<DfsKnnResults> knnResults,
SearchPhaseResults<SearchPhaseResult> queryResult,
Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
AbstractSearchAsyncAction<?> context
) {
DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
super(NAME);
this.progressListener = context.getTask().getProgressListener();
this.queryResult = queryResult;
this.searchResults = searchResults;
this.dfs = dfs;
this.knnResults = knnResults;
this.nextPhaseFactory = nextPhaseFactory;
this.client = client;
this.context = context;
this.searchTransportService = context.getSearchTransport();
}

// protected for testing
protected SearchPhase nextPhase(AggregatedDfs dfs) {
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
}

@SuppressWarnings("unchecked")
@Override
protected void run() {
List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
AggregatedDfs dfs = aggregateDfs(searchResults);
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
// to free up memory early
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
queryResult,
searchResults.size(),
() -> context.executeNextPhase(NAME, () -> nextPhaseFactory.apply(queryResult)),
() -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
context
);

List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
for (final DfsSearchResult dfsResult : searchResults) {
final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
final int shardIndex = dfsResult.getShardIndex();
QuerySearchRequest querySearchRequest = new QuerySearchRequest(
context.getOriginalIndices(shardIndex),
dfsResult.getContextId(),
rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
dfs
);
final Transport.Connection connection;
@@ -97,11 +100,8 @@ protected void run() {
shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
continue;
}
searchTransportService.sendExecuteQuery(
connection,
querySearchRequest,
context.getTask(),
new SearchActionListener<>(shardTarget, shardIndex) {
context.getSearchTransport()
.sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {

@Override
protected void innerOnResponse(QuerySearchResult response) {
@@ -126,8 +126,7 @@ public void onFailure(Exception exception) {
}
}
}
}
);
});
}
}

@@ -144,7 +143,7 @@ private void shardFailure(
}

// package private for testing
ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
SearchSourceBuilder source = request.source();
if (source == null || source.knnSearch().isEmpty()) {
return request;
@@ -180,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {

return request;
}

private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
Contributor Author: The two static methods move without changes; they are now only used from this class. (A standalone sketch of the statistics merge they perform follows this file's diff.)

if (request.hasKnnSearch() == false) {
return null;
}
SearchSourceBuilder source = request.source();
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
for (int i = 0; i < source.knnSearch().size(); i++) {
topDocsLists.add(new ArrayList<>());
nestedPath.add(new SetOnce<>());
}

for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
if (dfsSearchResult.knnResults() != null) {
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
topDocsLists.get(i).add(shardTopDocs);
nestedPath.get(i).trySet(knnResults.getNestedPath());
}
}
}

List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
for (int i = 0; i < source.knnSearch().size(); i++) {
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
}
return mergedResults;
}

private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
Map<Term, TermStatistics> termStatistics = new HashMap<>();
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
long aggMaxDoc = 0;
for (DfsSearchResult lEntry : results) {
final Term[] terms = lEntry.terms();
final TermStatistics[] stats = lEntry.termStatistics();
assert terms.length == stats.length;
for (int i = 0; i < terms.length; i++) {
assert terms[i] != null;
if (stats[i] == null) {
continue;
}
TermStatistics existing = termStatistics.get(terms[i]);
if (existing != null) {
assert terms[i].bytes().equals(existing.term());
termStatistics.put(
terms[i],
new TermStatistics(
existing.term(),
existing.docFreq() + stats[i].docFreq(),
existing.totalTermFreq() + stats[i].totalTermFreq()
)
);
} else {
termStatistics.put(terms[i], stats[i]);
}

}

assert lEntry.fieldStatistics().containsKey(null) == false;
for (var entry : lEntry.fieldStatistics().entrySet()) {
String key = entry.getKey();
CollectionStatistics value = entry.getValue();
if (value == null) {
continue;
}
assert key != null;
CollectionStatistics existing = fieldStatistics.get(key);
if (existing != null) {
CollectionStatistics merged = new CollectionStatistics(
key,
existing.maxDoc() + value.maxDoc(),
existing.docCount() + value.docCount(),
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
existing.sumDocFreq() + value.sumDocFreq()
);
fieldStatistics.put(key, merged);
} else {
fieldStatistics.put(key, value);
}
}
aggMaxDoc += lEntry.maxDoc();
}
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
}
}
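As the review comment above notes, mergeKnnResults and aggregateDfs move into DfsQueryPhase unchanged. The heart of aggregateDfs is its merge rule: per term, docFreq and totalTermFreq are summed across shards, so that every shard later scores with the same cluster-wide frequencies. Below is a minimal, standalone sketch of that rule using only Lucene types; the term and counts are made up for illustration.

```java
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;

import java.util.HashMap;
import java.util.Map;

public class DfsMergeSketch {
    public static void main(String[] args) {
        Map<Term, TermStatistics> merged = new HashMap<>();
        Term term = new Term("title", "elasticsearch");
        // Hypothetical per-shard statistics: shard 1 sees the term in 3 docs
        // (7 occurrences in total), shard 2 in 2 docs (5 occurrences).
        TermStatistics[] perShard = {
            new TermStatistics(term.bytes(), 3, 7),
            new TermStatistics(term.bytes(), 2, 5) };
        for (TermStatistics stats : perShard) {
            // The same rule aggregateDfs applies: sum docFreq and totalTermFreq per term.
            merged.merge(
                term,
                stats,
                (a, b) -> new TermStatistics(a.term(), a.docFreq() + b.docFreq(), a.totalTermFreq() + b.totalTermFreq())
            );
        }
        TermStatistics total = merged.get(term);
        // Prints docFreq=5 totalTermFreq=12, the cluster-wide frequencies.
        System.out.println("docFreq=" + total.docFreq() + " totalTermFreq=" + total.totalTermFreq());
    }
}
```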
server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java
@@ -12,37 +12,47 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;

import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;

/**
* This search phase is an optional phase that executes field-collapsing on the inner hits once all hits are fetched from the
* shards. It only executes if field collapsing is requested in the search request and otherwise forwards to the next phase
* immediately.
*/
final class ExpandSearchPhase extends SearchPhase {
class ExpandSearchPhase extends SearchPhase {

static final String NAME = "expand";

private final AbstractSearchAsyncAction<?> context;
private final SearchHits searchHits;
private final Supplier<SearchPhase> nextPhase;
private final SearchResponseSections searchResponseSections;
private final AtomicArray<SearchPhaseResult> queryPhaseResults;

ExpandSearchPhase(AbstractSearchAsyncAction<?> context, SearchHits searchHits, Supplier<SearchPhase> nextPhase) {
ExpandSearchPhase(
AbstractSearchAsyncAction<?> context,
SearchResponseSections searchResponseSections,
AtomicArray<SearchPhaseResult> queryPhaseResults
) {
super(NAME);
this.context = context;
this.searchHits = searchHits;
this.nextPhase = nextPhase;
this.searchResponseSections = searchResponseSections;
this.queryPhaseResults = queryPhaseResults;
}

// protected for tests
protected SearchPhase nextPhase() {
return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults);
}

/**
@@ -55,14 +65,15 @@ private boolean isCollapseRequest() {

@Override
protected void run() {
var searchHits = searchResponseSections.hits();
if (isCollapseRequest() == false || searchHits.getHits().length == 0) {
onPhaseDone();
} else {
doRun();
doRun(searchHits);
}
}

private void doRun() {
private void doRun(SearchHits searchHits) {
SearchRequest searchRequest = context.getRequest();
CollapseBuilder collapseBuilder = searchRequest.source().collapse();
final List<InnerHitBuilder> innerHitBuilders = collapseBuilder.getInnerHits();
@@ -171,6 +182,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde
}

private void onPhaseDone() {
context.executeNextPhase(NAME, nextPhase);
context.executeNextPhase(NAME, this::nextPhase);
}
}
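A pattern this PR repeats in DfsQueryPhase and ExpandSearchPhase: the final modifier and the injected next-phase factory are replaced by a protected nextPhase() hook marked "protected for testing". A self-contained sketch of that seam, with hypothetical stand-in types rather than the real Elasticsearch classes:

```java
// Stand-in for a search phase whose production successor is hard-wired in a hook.
class Phase {
    // Production wiring lives in an overridable method (cf. nextPhase() above).
    protected String nextPhase() {
        return "fetch_lookup_fields";
    }

    String run() {
        // The phase chains to whatever the hook returns.
        return "ran, next=" + nextPhase();
    }
}

public class SeamSketch {
    public static void main(String[] args) {
        // A test overrides the hook to stub or record the transition,
        // instead of threading a factory through the constructor.
        Phase underTest = new Phase() {
            @Override
            protected String nextPhase() {
                return "recorded_by_test";
            }
        };
        System.out.println(underTest.run()); // prints: ran, next=recorded_by_test
    }
}
```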
server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java
@@ -51,9 +51,7 @@ final class FetchLookupFieldsPhase extends SearchPhase {
this.queryResults = queryResults;
}

private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {

}
private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {}

private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHits) {
final Map<String, List<SearchHit>> perClusters = new HashMap<>();
@@ -80,7 +78,7 @@ private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHi
protected void run() {
final List<Cluster> clusters = groupLookupFieldsByClusterAlias(searchResponse.hits);
if (clusters.isEmpty()) {
context.sendSearchResponse(searchResponse, queryResults);
sendResponse();
return;
}
doRun(clusters);
@@ -132,9 +130,9 @@ public void onResponse(MultiSearchResponse items) {
}
}
if (failure != null) {
context.onPhaseFailure(NAME, "failed to fetch lookup fields", failure);
onFailure(failure);
} else {
context.sendSearchResponse(searchResponse, queryResults);
sendResponse();
}
}

@@ -144,4 +142,8 @@ public void onFailure(Exception e) {
}
});
}

private void sendResponse() {
context.sendSearchResponse(searchResponse, queryResults);
}
}