1414import org .apache .lucene .search .LeafCollector ;
1515import org .apache .lucene .search .Query ;
1616import org .apache .lucene .search .ScoreDoc ;
17- import org .apache .lucene .search .ScoreMode ;
1817import org .apache .lucene .search .Sort ;
1918import org .apache .lucene .search .SortField ;
2019import org .apache .lucene .search .TopDocsCollector ;
2120import org .apache .lucene .search .TopFieldCollectorManager ;
2221import org .apache .lucene .search .TopScoreDocCollectorManager ;
22+ import org .apache .lucene .search .Weight ;
2323import org .elasticsearch .common .Strings ;
2424import org .elasticsearch .compute .data .BlockFactory ;
2525import org .elasticsearch .compute .data .DocBlock ;
3636import org .elasticsearch .search .sort .SortBuilder ;
3737
3838import java .io .IOException ;
39+ import java .io .UncheckedIOException ;
3940import java .util .ArrayList ;
4041import java .util .Arrays ;
4142import java .util .List ;
4243import java .util .Optional ;
4344import java .util .function .Function ;
4445import java .util .stream .Collectors ;
4546
46- import static org .apache .lucene .search .ScoreMode .TOP_DOCS ;
47- import static org .apache .lucene .search .ScoreMode .TOP_DOCS_WITH_SCORES ;
48-
4947/**
5048 * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
5149 */
@@ -63,16 +61,16 @@ public Factory(
6361 int maxPageSize ,
6462 int limit ,
6563 List <SortBuilder <?>> sorts ,
66- boolean scoring
64+ boolean needsScore
6765 ) {
68- super (contexts , queryFunction , dataPartitioning , taskConcurrency , limit , scoring ? TOP_DOCS_WITH_SCORES : TOP_DOCS );
66+ super (contexts , weightFunction ( queryFunction , sorts ), dataPartitioning , taskConcurrency , limit , needsScore );
6967 this .maxPageSize = maxPageSize ;
7068 this .sorts = sorts ;
7169 }
7270
7371 @ Override
7472 public SourceOperator get (DriverContext driverContext ) {
75- return new LuceneTopNSourceOperator (driverContext .blockFactory (), maxPageSize , sorts , limit , sliceQueue , scoreMode );
73+ return new LuceneTopNSourceOperator (driverContext .blockFactory (), maxPageSize , sorts , limit , sliceQueue , needsScore );
7674 }
7775
7876 public int maxPageSize () {
@@ -88,8 +86,8 @@ public String describe() {
8886 + maxPageSize
8987 + ", limit = "
9088 + limit
91- + ", scoreMode = "
92- + scoreMode
89+ + ", needsScore = "
90+ + needsScore
9391 + ", sorts = ["
9492 + notPrettySorts
9593 + "]]" ;
@@ -108,20 +106,20 @@ public String describe() {
108106 private PerShardCollector perShardCollector ;
109107 private final List <SortBuilder <?>> sorts ;
110108 private final int limit ;
111- private final ScoreMode scoreMode ;
109+ private final boolean needsScore ;
112110
113111 public LuceneTopNSourceOperator (
114112 BlockFactory blockFactory ,
115113 int maxPageSize ,
116114 List <SortBuilder <?>> sorts ,
117115 int limit ,
118116 LuceneSliceQueue sliceQueue ,
119- ScoreMode scoreMode
117+ boolean needsScore
120118 ) {
121119 super (blockFactory , maxPageSize , sliceQueue );
122120 this .sorts = sorts ;
123121 this .limit = limit ;
124- this .scoreMode = scoreMode ;
122+ this .needsScore = needsScore ;
125123 }
126124
127125 @ Override
@@ -163,7 +161,7 @@ private Page collect() throws IOException {
163161 try {
164162 if (perShardCollector == null || perShardCollector .shardContext .index () != scorer .shardContext ().index ()) {
165163 // TODO: share the bottom between shardCollectors
166- perShardCollector = newPerShardCollector (scorer .shardContext (), sorts , limit );
164+ perShardCollector = newPerShardCollector (scorer .shardContext (), sorts , needsScore , limit );
167165 }
168166 var leafCollector = perShardCollector .getLeafCollector (scorer .leafReaderContext ());
169167 scorer .scoreNextRange (leafCollector , scorer .leafReaderContext ().reader ().getLiveDocs (), maxPageSize );
@@ -261,7 +259,7 @@ private float getScore(ScoreDoc scoreDoc) {
261259 }
262260
263261 private DoubleVector .Builder scoreVectorOrNull (int size ) {
264- if (scoreMode . needsScores () ) {
262+ if (needsScore ) {
265263 return blockFactory .newDoubleVectorFixedBuilder (size );
266264 } else {
267265 return null ;
@@ -271,37 +269,11 @@ private DoubleVector.Builder scoreVectorOrNull(int size) {
271269 @ Override
272270 protected void describe (StringBuilder sb ) {
273271 sb .append (", limit = " ).append (limit );
274- sb .append (", scoreMode = " ).append (scoreMode );
272+ sb .append (", needsScore = " ).append (needsScore );
275273 String notPrettySorts = sorts .stream ().map (Strings ::toString ).collect (Collectors .joining ("," ));
276274 sb .append (", sorts = [" ).append (notPrettySorts ).append ("]" );
277275 }
278276
279- PerShardCollector newPerShardCollector (ShardContext shardContext , List <SortBuilder <?>> sorts , int limit ) throws IOException {
280- Optional <SortAndFormats > sortAndFormats = shardContext .buildSort (sorts );
281- if (sortAndFormats .isEmpty ()) {
282- throw new IllegalStateException ("sorts must not be disabled in TopN" );
283- }
284- if (scoreMode .needsScores () == false ) {
285- return new NonScoringPerShardCollector (shardContext , sortAndFormats .get ().sort , limit );
286- } else {
287- SortField [] sortFields = sortAndFormats .get ().sort .getSort ();
288- if (sortFields != null && sortFields .length == 1 && sortFields [0 ].needsScores () && sortFields [0 ].getReverse () == false ) {
289- // SORT _score DESC
290- return new ScoringPerShardCollector (shardContext , new TopScoreDocCollectorManager (limit , null , 0 ).newCollector ());
291- } else {
292- // SORT ..., _score, ...
293- var sort = new Sort ();
294- if (sortFields != null ) {
295- var l = new ArrayList <>(Arrays .asList (sortFields ));
296- l .add (SortField .FIELD_DOC );
297- l .add (SortField .FIELD_SCORE );
298- sort = new Sort (l .toArray (SortField []::new ));
299- }
300- return new ScoringPerShardCollector (shardContext , new TopFieldCollectorManager (sort , limit , null , 0 ).newCollector ());
301- }
302- }
303- }
304-
305277 abstract static class PerShardCollector {
306278 private final ShardContext shardContext ;
307279 private final TopDocsCollector <?> collector ;
@@ -336,4 +308,44 @@ static final class ScoringPerShardCollector extends PerShardCollector {
336308 super (shardContext , topDocsCollector );
337309 }
338310 }
311+
312+ private static Function <ShardContext , Weight > weightFunction (Function <ShardContext , Query > queryFunction , List <SortBuilder <?>> sorts ) {
313+ return ctx -> {
314+ final var query = queryFunction .apply (ctx );
315+ final var searcher = ctx .searcher ();
316+ try {
317+ // we create a collector with a limit of 1 to determine the appropriate score mode to use.
318+ var scoreMode = newPerShardCollector (ctx , sorts , false , 1 ).collector .scoreMode ();
319+ return searcher .createWeight (searcher .rewrite (query ), scoreMode , 1 );
320+ } catch (IOException e ) {
321+ throw new UncheckedIOException (e );
322+ }
323+ };
324+ }
325+
326+ private static PerShardCollector newPerShardCollector (ShardContext context , List <SortBuilder <?>> sorts , boolean needsScore , int limit )
327+ throws IOException {
328+ Optional <SortAndFormats > sortAndFormats = context .buildSort (sorts );
329+ if (sortAndFormats .isEmpty ()) {
330+ throw new IllegalStateException ("sorts must not be disabled in TopN" );
331+ }
332+ if (needsScore == false ) {
333+ return new NonScoringPerShardCollector (context , sortAndFormats .get ().sort , limit );
334+ }
335+ SortField [] sortFields = sortAndFormats .get ().sort .getSort ();
336+ if (sortFields != null && sortFields .length == 1 && sortFields [0 ].needsScores () && sortFields [0 ].getReverse () == false ) {
337+ // SORT _score DESC
338+ return new ScoringPerShardCollector (context , new TopScoreDocCollectorManager (limit , null , 0 ).newCollector ());
339+ }
340+
341+ // SORT ..., _score, ...
342+ var sort = new Sort ();
343+ if (sortFields != null ) {
344+ var l = new ArrayList <>(Arrays .asList (sortFields ));
345+ l .add (SortField .FIELD_DOC );
346+ l .add (SortField .FIELD_SCORE );
347+ sort = new Sort (l .toArray (SortField []::new ));
348+ }
349+ return new ScoringPerShardCollector (context , new TopFieldCollectorManager (sort , limit , null , 0 ).newCollector ());
350+ }
339351}
0 commit comments