1212import org .apache .lucene .search .Scorable ;
1313import org .apache .lucene .search .ScoreMode ;
1414import org .apache .lucene .search .Weight ;
15+ import org .elasticsearch .compute .data .Block ;
1516import org .elasticsearch .compute .data .BlockFactory ;
16- import org .elasticsearch .compute .data .BooleanBlock ;
17- import org .elasticsearch .compute .data .LongBlock ;
17+ import org .elasticsearch .compute .data .BlockUtils ;
18+ import org .elasticsearch .compute .data .ElementType ;
19+ import org .elasticsearch .compute .data .LongVector ;
1820import org .elasticsearch .compute .data .Page ;
1921import org .elasticsearch .compute .operator .DriverContext ;
2022import org .elasticsearch .compute .operator .SourceOperator ;
2123import org .elasticsearch .core .RefCounted ;
2224import org .elasticsearch .core .Releasables ;
2325
2426import java .io .IOException ;
27+ import java .util .HashMap ;
2528import java .util .List ;
29+ import java .util .Map ;
2630import java .util .function .Function ;
2731
/**
 * 2. a boolean flag ("seen") that is always true, meaning that the group (all items) always exists
 */
3438public class LuceneCountOperator extends LuceneOperator {
35-
36- private static final int PAGE_SIZE = 1 ;
37-
38- private int totalHits = 0 ;
39- private int remainingDocs ;
40-
41- private final LeafCollector leafCollector ;
42-
4339 public static class Factory extends LuceneOperator .Factory {
4440 private final List <? extends RefCounted > shardRefCounters ;
41+ private final List <ElementType > tagTypes ;
4542
4643 public Factory (
4744 List <? extends ShardContext > contexts ,
4845 Function <ShardContext , List <LuceneSliceQueue .QueryAndTags >> queryFunction ,
4946 DataPartitioning dataPartitioning ,
5047 int taskConcurrency ,
48+ List <ElementType > tagTypes ,
5149 int limit
5250 ) {
5351 super (
@@ -61,11 +59,12 @@ public Factory(
6159 shardContext -> ScoreMode .COMPLETE_NO_SCORES
6260 );
6361 this .shardRefCounters = contexts ;
62+ this .tagTypes = tagTypes ;
6463 }
6564
6665 @ Override
6766 public SourceOperator get (DriverContext driverContext ) {
68- return new LuceneCountOperator (shardRefCounters , driverContext .blockFactory (), sliceQueue , limit );
67+ return new LuceneCountOperator (shardRefCounters , driverContext .blockFactory (), sliceQueue , tagTypes , limit );
6968 }
7069
7170 @ Override
@@ -74,35 +73,20 @@ public String describe() {
7473 }
7574 }
7675
76+ private final List <ElementType > tagTypes ;
77+ private final Map <List <Object >, PerTagsState > tagsToState = new HashMap <>();
78+ private int remainingDocs ;
79+
7780 public LuceneCountOperator (
7881 List <? extends RefCounted > shardRefCounters ,
7982 BlockFactory blockFactory ,
8083 LuceneSliceQueue sliceQueue ,
84+ List <ElementType > tagTypes ,
8185 int limit
8286 ) {
83- super (shardRefCounters , blockFactory , PAGE_SIZE , sliceQueue );
87+ super (shardRefCounters , blockFactory , Integer .MAX_VALUE , sliceQueue );
88+ this .tagTypes = tagTypes ;
8489 this .remainingDocs = limit ;
85- this .leafCollector = new LeafCollector () {
86- @ Override
87- public void setScorer (Scorable scorer ) {}
88-
89- @ Override
90- public void collect (DocIdStream stream ) throws IOException {
91- if (remainingDocs > 0 ) {
92- int count = Math .min (stream .count (), remainingDocs );
93- totalHits += count ;
94- remainingDocs -= count ;
95- }
96- }
97-
98- @ Override
99- public void collect (int doc ) {
100- if (remainingDocs > 0 ) {
101- remainingDocs --;
102- totalHits ++;
103- }
104- }
105- };
10690 }
10791
10892 @ Override
@@ -124,59 +108,133 @@ protected Page getCheckedOutput() throws IOException {
124108 long start = System .nanoTime ();
125109 try {
126110 final LuceneScorer scorer = getCurrentOrLoadNextScorer ();
127- // no scorer means no more docs
128111 if (scorer == null ) {
129112 remainingDocs = 0 ;
130113 } else {
131- if (scorer .tags ().isEmpty () == false ) {
132- throw new UnsupportedOperationException ("tags not supported by " + getClass ());
133- }
134- Weight weight = scorer .weight ();
135- var leafReaderContext = scorer .leafReaderContext ();
136- // see org.apache.lucene.search.TotalHitCountCollector
137- int leafCount = weight .count (leafReaderContext );
138- if (leafCount != -1 ) {
139- // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
140- // handle doc partitioning where the same leaf can be seen multiple times
141- // since the count is global, consider it only for the first partition and skip the rest
142- // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
143- if (scorer .position () == 0 ) {
144- // check to not count over the desired number of docs/limit
145- var count = Math .min (leafCount , remainingDocs );
146- totalHits += count ;
147- remainingDocs -= count ;
148- }
149- scorer .markAsDone ();
150- } else {
151- // could not apply shortcut, trigger the search
152- // TODO: avoid iterating all documents in multiple calls to make cancellation more responsive.
153- scorer .scoreNextRange (leafCollector , leafReaderContext .reader ().getLiveDocs (), remainingDocs );
154- }
114+ count (scorer );
115+ }
116+
117+ if (remainingDocs <= 0 ) {
118+ return buildResult ();
119+ }
120+ return null ;
121+ } finally {
122+ processingNanos += System .nanoTime () - start ;
123+ }
124+ }
125+
126+ private void count (LuceneScorer scorer ) throws IOException {
127+ PerTagsState state = tagsToState .computeIfAbsent (scorer .tags (), t -> new PerTagsState ());
128+ Weight weight = scorer .weight ();
129+ var leafReaderContext = scorer .leafReaderContext ();
130+ // see org.apache.lucene.search.TotalHitCountCollector
131+ int leafCount = weight .count (leafReaderContext );
132+ if (leafCount != -1 ) {
133+ // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
134+ // handle doc partitioning where the same leaf can be seen multiple times
135+ // since the count is global, consider it only for the first partition and skip the rest
136+ // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
137+ if (scorer .position () == 0 ) {
138+ // check to not count over the desired number of docs/limit
139+ var count = Math .min (leafCount , remainingDocs );
140+ state .totalHits += count ;
141+ remainingDocs -= count ;
155142 }
143+ scorer .markAsDone ();
144+ } else {
145+ // could not apply shortcut, trigger the search
146+ // TODO: avoid iterating all documents in multiple calls to make cancellation more responsive.
147+ scorer .scoreNextRange (state , leafReaderContext .reader ().getLiveDocs (), remainingDocs );
148+ }
149+ }
156150
157- Page page = null ;
158- // emit only one page
159- if (remainingDocs <= 0 && pagesEmitted == 0 ) {
160- LongBlock count = null ;
161- BooleanBlock seen = null ;
162- try {
163- count = blockFactory .newConstantLongBlockWith (totalHits , PAGE_SIZE );
164- seen = blockFactory .newConstantBooleanBlockWith (true , PAGE_SIZE );
165- page = new Page (PAGE_SIZE , count , seen );
166- } finally {
167- if (page == null ) {
168- Releasables .closeExpectNoException (count , seen );
169- }
151+ private Page buildResult () {
152+ return switch (tagsToState .size ()) {
153+ case 0 -> null ;
154+ case 1 -> {
155+ Map .Entry <List <Object >, PerTagsState > e = tagsToState .entrySet ().iterator ().next ();
156+ yield buildConstantBlocksResult (e .getKey (), e .getValue ());
157+ }
158+ default -> buildNonConstantBlocksResult ();
159+ };
160+ }
161+
162+ private Page buildConstantBlocksResult (List <Object > tags , PerTagsState state ) {
163+ Block [] blocks = new Block [2 + tagTypes .size ()];
164+ int b = 0 ;
165+ try {
166+ blocks [b ++] = blockFactory .newConstantLongBlockWith (state .totalHits , 1 );
167+ blocks [b ++] = blockFactory .newConstantBooleanBlockWith (true , 1 );
168+ for (Object e : tags ) {
169+ blocks [b ++] = BlockUtils .constantBlock (blockFactory , e , 1 );
170+ }
171+ Page page = new Page (1 , blocks );
172+ blocks = null ;
173+ return page ;
174+ } finally {
175+ if (blocks != null ) {
176+ Releasables .closeExpectNoException (blocks );
177+ }
178+ }
179+ }
180+
181+ private Page buildNonConstantBlocksResult () {
182+ BlockUtils .BuilderWrapper [] builders = new BlockUtils .BuilderWrapper [tagTypes .size ()];
183+ Block [] blocks = new Block [2 + tagTypes .size ()];
184+ try (LongVector .Builder countBuilder = blockFactory .newLongVectorBuilder (tagsToState .size ())) {
185+ int b = 0 ;
186+ for (ElementType t : tagTypes ) {
187+ builders [b ++] = BlockUtils .wrapperFor (blockFactory , t , tagsToState .size ());
188+ }
189+
190+ for (Map .Entry <List <Object >, PerTagsState > e : tagsToState .entrySet ()) {
191+ countBuilder .appendLong (e .getValue ().totalHits );
192+ b = 0 ;
193+ for (Object t : e .getKey ()) {
194+ builders [b ++].accept (t );
170195 }
171196 }
197+
198+ blocks [0 ] = countBuilder .build ().asBlock ();
199+ blocks [1 ] = blockFactory .newConstantBooleanBlockWith (true , tagsToState .size ());
200+ for (b = 0 ; b < builders .length ; b ++) {
201+ blocks [2 + b ] = builders [b ].builder ().build ();
202+ builders [b ++] = null ;
203+ }
204+ Page page = new Page (tagsToState .size (), blocks );
205+ blocks = null ;
172206 return page ;
173207 } finally {
174- processingNanos += System . nanoTime () - start ;
208+ Releasables . closeExpectNoException ( Releasables . wrap ( builders ), blocks == null ? () -> {} : Releasables . wrap ( blocks )) ;
175209 }
176210 }
177211
178212 @ Override
179213 protected void describe (StringBuilder sb ) {
180214 sb .append (", remainingDocs=" ).append (remainingDocs );
181215 }
216+
217+ private class PerTagsState implements LeafCollector {
218+ long totalHits ;
219+
220+ @ Override
221+ public void setScorer (Scorable scorer ) {}
222+
223+ @ Override
224+ public void collect (DocIdStream stream ) throws IOException {
225+ if (remainingDocs > 0 ) {
226+ int count = Math .min (stream .count (), remainingDocs );
227+ totalHits += count ;
228+ remainingDocs -= count ;
229+ }
230+ }
231+
232+ @ Override
233+ public void collect (int doc ) {
234+ if (remainingDocs > 0 ) {
235+ remainingDocs --;
236+ totalHits ++;
237+ }
238+ }
239+ }
182240}
0 commit comments