  */
 package org.elasticsearch.action.search;
 
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermStatistics;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.join.ScoreMode;
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.index.query.NestedQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.transport.Transport;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
-import java.util.function.Function;
+import java.util.Map;
 
 /**
  * This search phase fans out to every shard to execute a distributed search with pre-collected distributed frequencies for all
  * search terms used in the actual search query. It is very similar to the default query-then-fetch search phase, but it does not
  * retry on another shard if any of the shards fail. Failures are treated as shard failures and are counted as a non-successful
  * operation.
  * @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
  */
-final class DfsQueryPhase extends SearchPhase {
+class DfsQueryPhase extends SearchPhase {
 
     public static final String NAME = "dfs_query";
 
     private final SearchPhaseResults<SearchPhaseResult> queryResult;
-    private final List<DfsSearchResult> searchResults;
-    private final AggregatedDfs dfs;
-    private final List<DfsKnnResults> knnResults;
-    private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
+    private final Client client;
     private final AbstractSearchAsyncAction<?> context;
-    private final SearchTransportService searchTransportService;
     private final SearchProgressListener progressListener;
 
-    DfsQueryPhase(
-        List<DfsSearchResult> searchResults,
-        AggregatedDfs dfs,
-        List<DfsKnnResults> knnResults,
-        SearchPhaseResults<SearchPhaseResult> queryResult,
-        Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
-        AbstractSearchAsyncAction<?> context
-    ) {
+    DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
         super(NAME);
         this.progressListener = context.getTask().getProgressListener();
         this.queryResult = queryResult;
-        this.searchResults = searchResults;
-        this.dfs = dfs;
-        this.knnResults = knnResults;
-        this.nextPhaseFactory = nextPhaseFactory;
+        this.client = client;
         this.context = context;
-        this.searchTransportService = context.getSearchTransport();
     }
 
+    // protected for testing
+    protected SearchPhase nextPhase(AggregatedDfs dfs) {
+        return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
+    }
+
+    @SuppressWarnings("unchecked")
     @Override
     protected void run() {
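+        // gather the per-shard DFS results collected by the preceding phase and fold their statistics into one AggregatedDfs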
+        List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
+        AggregatedDfs dfs = aggregateDfs(searchResults);
         // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
         // to free up memory early
         final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
             queryResult,
             searchResults.size(),
-            () -> context.executeNextPhase(NAME, () -> nextPhaseFactory.apply(queryResult)),
+            () -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
             context
         );
 
+        List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
         for (final DfsSearchResult dfsResult : searchResults) {
             final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
             final int shardIndex = dfsResult.getShardIndex();
             QuerySearchRequest querySearchRequest = new QuerySearchRequest(
                 context.getOriginalIndices(shardIndex),
                 dfsResult.getContextId(),
-                rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
+                rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
                 dfs
             );
             final Transport.Connection connection;
@@ -97,11 +100,8 @@ protected void run() {
                 shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
                 continue;
             }
-            searchTransportService.sendExecuteQuery(
-                connection,
-                querySearchRequest,
-                context.getTask(),
-                new SearchActionListener<>(shardTarget, shardIndex) {
+            context.getSearchTransport()
+                .sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {
 
                     @Override
                     protected void innerOnResponse(QuerySearchResult response) {
@@ -126,8 +126,7 @@ public void onFailure(Exception exception) {
                         }
                     }
                 }
-            }
-            );
+                });
         }
     }
 
@@ -144,7 +143,7 @@ private void shardFailure(
     }
 
     // package private for testing
-    ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
+    ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
         SearchSourceBuilder source = request.source();
         if (source == null || source.knnSearch().isEmpty()) {
             return request;
@@ -180,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
 
         return request;
     }
+
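+    // Merges the per-shard kNN hits for each kNN query in the request into a single global top-k list per query.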
+    private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
+        if (request.hasKnnSearch() == false) {
+            return null;
+        }
+        SearchSourceBuilder source = request.source();
+        List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
+        List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
+        for (int i = 0; i < source.knnSearch().size(); i++) {
+            topDocsLists.add(new ArrayList<>());
+            nestedPath.add(new SetOnce<>());
+        }
+
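+        // collect each shard's top docs per kNN query, tagging every hit with its shard index so it stays traceable after the merge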
+        for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
+            if (dfsSearchResult.knnResults() != null) {
+                for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
+                    DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
+                    ScoreDoc[] scoreDocs = knnResults.scoreDocs();
+                    TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
+                    TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
+                    SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
+                    topDocsLists.get(i).add(shardTopDocs);
+                    nestedPath.get(i).trySet(knnResults.getNestedPath());
+                }
+            }
+        }
+
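+        // reduce each query's shard lists to one global top-k via Lucene's TopDocs.merge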
+        List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
+        for (int i = 0; i < source.knnSearch().size(); i++) {
+            TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
+            mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
+        }
+        return mergedResults;
+    }
+
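+    // Sums term and field statistics from every shard so that all shards score against the same global frequencies.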
+    private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
+        Map<Term, TermStatistics> termStatistics = new HashMap<>();
+        Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
+        long aggMaxDoc = 0;
+        for (DfsSearchResult lEntry : results) {
+            final Term[] terms = lEntry.terms();
+            final TermStatistics[] stats = lEntry.termStatistics();
+            assert terms.length == stats.length;
+            for (int i = 0; i < terms.length; i++) {
+                assert terms[i] != null;
+                if (stats[i] == null) {
+                    continue;
+                }
+                TermStatistics existing = termStatistics.get(terms[i]);
+                if (existing != null) {
+                    assert terms[i].bytes().equals(existing.term());
+                    termStatistics.put(
+                        terms[i],
+                        new TermStatistics(
+                            existing.term(),
+                            existing.docFreq() + stats[i].docFreq(),
+                            existing.totalTermFreq() + stats[i].totalTermFreq()
+                        )
+                    );
+                } else {
+                    termStatistics.put(terms[i], stats[i]);
+                }
+            }
+
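+            // merge field-level collection statistics the same way, summing the per-shard counts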
+            assert lEntry.fieldStatistics().containsKey(null) == false;
+            for (var entry : lEntry.fieldStatistics().entrySet()) {
+                String key = entry.getKey();
+                CollectionStatistics value = entry.getValue();
+                if (value == null) {
+                    continue;
+                }
+                assert key != null;
+                CollectionStatistics existing = fieldStatistics.get(key);
+                if (existing != null) {
+                    CollectionStatistics merged = new CollectionStatistics(
+                        key,
+                        existing.maxDoc() + value.maxDoc(),
+                        existing.docCount() + value.docCount(),
+                        existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
+                        existing.sumDocFreq() + value.sumDocFreq()
+                    );
+                    fieldStatistics.put(key, merged);
+                } else {
+                    fieldStatistics.put(key, value);
+                }
+            }
+            aggMaxDoc += lEntry.maxDoc();
+        }
+        return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
+    }
 }