77
88package org .elasticsearch .compute .lucene ;
99
10- import org .apache .lucene .search . Collector ;
10+ import org .apache .lucene .index . LeafReaderContext ;
1111import org .apache .lucene .search .FieldDoc ;
12+ import org .apache .lucene .search .LeafCollector ;
1213import org .apache .lucene .search .Query ;
14+ import org .apache .lucene .search .Scorable ;
1315import org .apache .lucene .search .ScoreDoc ;
1416import org .apache .lucene .search .ScoreMode ;
1517import org .apache .lucene .search .Sort ;
1618import org .apache .lucene .search .SortField ;
17- import org .apache .lucene .search .TopDocs ;
18- import org .apache .lucene .search .TopFieldCollector ;
19+ import org .apache .lucene .search .TopDocsCollector ;
1920import org .apache .lucene .search .TopFieldCollectorManager ;
21+ import org .apache .lucene .search .TopScoreDocCollectorManager ;
22+ import org .apache .lucene .util .PriorityQueue ;
2023import org .elasticsearch .common .Strings ;
2124import org .elasticsearch .compute .data .Block ;
2225import org .elasticsearch .compute .data .BlockFactory ;
@@ -103,15 +106,18 @@ protected Page maybeAppendScore(Page page, DoubleVector.Builder currentScoresBui
103106 }
104107
105108 float getScore (ScoreDoc scoreDoc ) {
106- FieldDoc fieldDoc = (FieldDoc ) scoreDoc ;
107- if (Float .isNaN (fieldDoc .score )) {
108- if (sorts != null ) {
109- return (Float ) fieldDoc .fields [sorts .size ()];
109+ if (scoreDoc instanceof FieldDoc fieldDoc ) {
110+ if (Float .isNaN (fieldDoc .score )) {
111+ if (sorts != null ) {
112+ return (Float ) fieldDoc .fields [sorts .size ()];
113+ } else {
114+ return (Float ) fieldDoc .fields [0 ];
115+ }
110116 } else {
111- return ( Float ) fieldDoc .fields [ 0 ] ;
117+ return fieldDoc .score ;
112118 }
113119 } else {
114- return fieldDoc .score ;
120+ return scoreDoc .score ;
115121 }
116122 }
117123
@@ -124,32 +130,62 @@ PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuild
124130 l .add (SortField .FIELD_SCORE );
125131 sort = new Sort (l .toArray (SortField []::new ));
126132 } else {
127- sort = new Sort () ;
133+ sort = null ;
128134 }
129135 return new ScoringPerShardCollector (shardContext , sort , limit );
130136 }
131137
132138 static class ScoringPerShardCollector extends PerShardCollector {
133139
134140 // TODO : make this configurable / inferrable?
135- private final TopFieldCollector collector ;
136141 private static final int MAX_HITS = 100_000 ;
137142 private static final int TOTAL_HITS_THRESHOLD = 100 ;
138143
139144 ScoringPerShardCollector (ShardContext shardContext , Sort sort , int limit ) {
140145 this .shardContext = shardContext ;
141- this .collector = new TopFieldCollectorManager (sort , Math .min (limit , MAX_HITS ), TOTAL_HITS_THRESHOLD ).newCollector ();
142- // TODO : use TopScoreDocCollectorManager when SORT _score DESC
146+ if (sort == null ) {
147+ this .collector = new UnsortedScoreCollector (new PriorityQueue <>(Math .min (limit , MAX_HITS )) {
148+ @ Override
149+ protected boolean lessThan (ScoreDoc a , ScoreDoc b ) {
150+ return a .doc > b .doc ;
151+ }
152+ });
153+ } else if (sort .needsScores ()) {
154+ this .collector = new TopScoreDocCollectorManager (Math .min (limit , MAX_HITS ), TOTAL_HITS_THRESHOLD ).newCollector ();
155+ } else {
156+ this .collector = new TopFieldCollectorManager (sort , Math .min (limit , MAX_HITS ), TOTAL_HITS_THRESHOLD ).newCollector ();
157+ }
158+ }
159+ }
160+
161+ private static class UnsortedScoreCollector extends TopDocsCollector <ScoreDoc > {
162+
163+ protected UnsortedScoreCollector (PriorityQueue <ScoreDoc > pq ) {
164+ super (pq );
143165 }
144166
145167 @ Override
146- Collector getCollector () {
147- return collector ;
168+ public LeafCollector getLeafCollector (LeafReaderContext context ) throws IOException {
169+ return new LeafCollector () {
170+ private Scorable scorable ;
171+
172+ @ Override
173+ public void setScorer (Scorable scorable ) {
174+ this .scorable = scorable ;
175+ }
176+
177+ @ Override
178+ public void collect (int docID ) throws IOException {
179+ float score = scorable .score ();
180+ pq .add (new ScoreDoc (docID , score ));
181+ totalHits ++;
182+ }
183+ };
148184 }
149185
150186 @ Override
151- TopDocs getTopDocs () {
152- return collector . topDocs () ;
187+ public ScoreMode scoreMode () {
188+ return ScoreMode . COMPLETE ;
153189 }
154190 }
155191}
0 commit comments