 */
 package org.elasticsearch.action.search;
 
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermStatistics;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.join.ScoreMode;
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.index.query.NestedQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.transport.Transport;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
-import java.util.function.Function;
+import java.util.Map;
 
 /**
  * This search phase fans out to every shard to execute a distributed search with pre-collected distributed frequencies for all
  * operation.
  * @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
  */
-final class DfsQueryPhase extends SearchPhase {
+class DfsQueryPhase extends SearchPhase {
+
+    public static final String NAME = "dfs_query";
+
     private final SearchPhaseResults<SearchPhaseResult> queryResult;
-    private final List<DfsSearchResult> searchResults;
-    private final AggregatedDfs dfs;
-    private final List<DfsKnnResults> knnResults;
-    private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
+    private final Client client;
     private final AbstractSearchAsyncAction<?> context;
-    private final SearchTransportService searchTransportService;
     private final SearchProgressListener progressListener;
 
-    DfsQueryPhase(
-        List<DfsSearchResult> searchResults,
-        AggregatedDfs dfs,
-        List<DfsKnnResults> knnResults,
-        SearchPhaseResults<SearchPhaseResult> queryResult,
-        Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
-        AbstractSearchAsyncAction<?> context
-    ) {
-        super("dfs_query");
+    DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
+        super(NAME);
         this.progressListener = context.getTask().getProgressListener();
         this.queryResult = queryResult;
-        this.searchResults = searchResults;
-        this.dfs = dfs;
-        this.knnResults = knnResults;
-        this.nextPhaseFactory = nextPhaseFactory;
+        this.client = client;
         this.context = context;
-        this.searchTransportService = context.getSearchTransport();
     }
 
+    // protected for testing
+    protected SearchPhase nextPhase(AggregatedDfs dfs) {
+        return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
+    }
+
+    @SuppressWarnings("unchecked")
     @Override
-    public void run() {
+    protected void run() {
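+        // collect the per-shard DFS results from the shared results array and combine their statistics into global frequencies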
+        List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
+        AggregatedDfs dfs = aggregateDfs(searchResults);
         // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
         // to free up memory early
         final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
             queryResult,
             searchResults.size(),
-            () -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)),
+            () -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
             context
         );
 
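+        // reduce the per-shard kNN hits to one global top-k list per kNN clause before fanning the query back out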
+        List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
         for (final DfsSearchResult dfsResult : searchResults) {
             final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
             final int shardIndex = dfsResult.getShardIndex();
             QuerySearchRequest querySearchRequest = new QuerySearchRequest(
                 context.getOriginalIndices(shardIndex),
                 dfsResult.getContextId(),
-                rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
+                rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
                 dfs
             );
             final Transport.Connection connection;
@@ -94,19 +100,16 @@ public void run() {
                 shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
                 continue;
             }
-            searchTransportService.sendExecuteQuery(
-                connection,
-                querySearchRequest,
-                context.getTask(),
-                new SearchActionListener<>(shardTarget, shardIndex) {
+            context.getSearchTransport()
+                .sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {
 
                     @Override
                     protected void innerOnResponse(QuerySearchResult response) {
                         try {
                             response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult());
                             counter.onResult(response);
                         } catch (Exception e) {
-                            context.onPhaseFailure(DfsQueryPhase.this, "", e);
+                            context.onPhaseFailure(NAME, "", e);
                         }
                     }
 
@@ -123,8 +126,7 @@ public void onFailure(Exception exception) {
                             }
                         }
                     }
-                }
-            );
+                });
         }
     }
 
@@ -141,7 +143,7 @@ private void shardFailure(
     }
 
     // package private for testing
-    ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
+    ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
         SearchSourceBuilder source = request.source();
         if (source == null || source.knnSearch().isEmpty()) {
             return request;
@@ -177,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
 
         return request;
     }
+
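+    /**
+     * Merges the kNN hits that each shard reported into one globally sorted top-k list per kNN clause of the request,
+     * or returns {@code null} when the request has no kNN sections.
+     */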
+    private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
+        if (request.hasKnnSearch() == false) {
+            return null;
+        }
+        SearchSourceBuilder source = request.source();
+        List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
+        List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
+        for (int i = 0; i < source.knnSearch().size(); i++) {
+            topDocsLists.add(new ArrayList<>());
+            nestedPath.add(new SetOnce<>());
+        }
+
+        for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
+            if (dfsSearchResult.knnResults() != null) {
+                for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
+                    DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
+                    ScoreDoc[] scoreDocs = knnResults.scoreDocs();
+                    TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
+                    TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
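+                    // tag each hit with its shard index so merged ScoreDocs stay traceable to the shard that produced them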
+                    SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
+                    topDocsLists.get(i).add(shardTopDocs);
+                    nestedPath.get(i).trySet(knnResults.getNestedPath());
+                }
+            }
+        }
+
+        List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
+        for (int i = 0; i < source.knnSearch().size(); i++) {
+            TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
+            mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
+        }
+        return mergedResults;
+    }
+
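+    /**
+     * Sums the term statistics, field statistics and maximum doc counts reported by each shard into the
+     * index-wide frequencies the query phase needs to score documents consistently across shards.
+     */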
+    private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
+        Map<Term, TermStatistics> termStatistics = new HashMap<>();
+        Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
+        long aggMaxDoc = 0;
+        for (DfsSearchResult lEntry : results) {
+            final Term[] terms = lEntry.terms();
+            final TermStatistics[] stats = lEntry.termStatistics();
+            assert terms.length == stats.length;
+            for (int i = 0; i < terms.length; i++) {
+                assert terms[i] != null;
+                if (stats[i] == null) {
+                    continue;
+                }
+                TermStatistics existing = termStatistics.get(terms[i]);
+                if (existing != null) {
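+                    // another shard already reported this term: accumulate its document and total term frequencies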
+                    assert terms[i].bytes().equals(existing.term());
+                    termStatistics.put(
+                        terms[i],
+                        new TermStatistics(
+                            existing.term(),
+                            existing.docFreq() + stats[i].docFreq(),
+                            existing.totalTermFreq() + stats[i].totalTermFreq()
+                        )
+                    );
+                } else {
+                    termStatistics.put(terms[i], stats[i]);
+                }
+            }
+
+            assert lEntry.fieldStatistics().containsKey(null) == false;
+            for (var entry : lEntry.fieldStatistics().entrySet()) {
+                String key = entry.getKey();
+                CollectionStatistics value = entry.getValue();
+                if (value == null) {
+                    continue;
+                }
+                assert key != null;
+                CollectionStatistics existing = fieldStatistics.get(key);
+                if (existing != null) {
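+                    // the field was already seen on another shard: add up its collection statistics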
+                    CollectionStatistics merged = new CollectionStatistics(
+                        key,
+                        existing.maxDoc() + value.maxDoc(),
+                        existing.docCount() + value.docCount(),
+                        existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
+                        existing.sumDocFreq() + value.sumDocFreq()
+                    );
+                    fieldStatistics.put(key, merged);
+                } else {
+                    fieldStatistics.put(key, value);
+                }
+            }
+            aggMaxDoc += lEntry.maxDoc();
+        }
+        return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
+    }
 }