diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/BlockUtils.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/BlockUtils.java
index 657f7b8504c94..6fb0b844e4dd3 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/BlockUtils.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/BlockUtils.java
@@ -174,6 +174,7 @@ public static Block[] fromList(BlockFactory blockFactory, List<List<Object>> lis
     /** Returns a deep copy of the given block, using the blockFactory for creating the copy block. */
     public static Block deepCopyOf(Block block, BlockFactory blockFactory) {
+        // TODO preserve constants here.
         try (Block.Builder builder = block.elementType().newBlockBuilder(block.getPositionCount(), blockFactory)) {
             builder.copyFrom(block, 0, block.getPositionCount());
             builder.mvOrdering(block.mvOrdering());
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
index cded3a3494738..de6f4c9a7dbdb 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
@@ -12,9 +12,11 @@
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
+import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BooleanBlock;
-import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.SourceOperator;
@@ -22,7 +24,9 @@
 import org.elasticsearch.core.Releasables;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Function;
 
 /**
@@ -32,22 +36,16 @@
  * 2. a bool flag (seen) that's always true meaning that the group (all items) always exists
  */
 public class LuceneCountOperator extends LuceneOperator {
-
-    private static final int PAGE_SIZE = 1;
-
-    private int totalHits = 0;
-    private int remainingDocs;
-
-    private final LeafCollector leafCollector;
-
     public static class Factory extends LuceneOperator.Factory {
         private final List<? extends RefCounted> shardRefCounters;
+        private final List<ElementType> tagTypes;
 
         public Factory(
             List<? extends ShardContext> contexts,
             Function<ShardContext, List<LuceneSliceQueue.QueryAndTags>> queryFunction,
             DataPartitioning dataPartitioning,
             int taskConcurrency,
+            List<ElementType> tagTypes,
             int limit
         ) {
             super(
@@ -61,11 +59,12 @@ public Factory(
                 shardContext -> ScoreMode.COMPLETE_NO_SCORES
             );
             this.shardRefCounters = contexts;
+            this.tagTypes = tagTypes;
         }
 
         @Override
         public SourceOperator get(DriverContext driverContext) {
-            return new LuceneCountOperator(shardRefCounters, driverContext.blockFactory(), sliceQueue, limit);
+            return new LuceneCountOperator(shardRefCounters, driverContext.blockFactory(), sliceQueue, tagTypes, limit);
         }
 
         @Override
@@ -74,35 +73,20 @@ public String describe() {
         }
     }
 
+    private final List<ElementType> tagTypes;
+    private final Map<List<Object>, PerTagsState> tagsToState = new HashMap<>();
+    private int remainingDocs;
+
     public LuceneCountOperator(
         List<? extends RefCounted> shardRefCounters,
         BlockFactory blockFactory,
         LuceneSliceQueue sliceQueue,
+        List<ElementType> tagTypes,
         int limit
     ) {
-        super(shardRefCounters, blockFactory, PAGE_SIZE, sliceQueue);
+        super(shardRefCounters, blockFactory, Integer.MAX_VALUE, sliceQueue);
+        this.tagTypes = tagTypes;
         this.remainingDocs = limit;
-        this.leafCollector = new LeafCollector() {
-            @Override
-            public void setScorer(Scorable scorer) {}
-
-            @Override
-            public void collect(DocIdStream stream) throws IOException {
-                if (remainingDocs > 0) {
-                    int count = Math.min(stream.count(), remainingDocs);
-                    totalHits += count;
-                    remainingDocs -= count;
-                }
-            }
-
-            @Override
-            public void collect(int doc) {
-                if (remainingDocs > 0) {
-                    remainingDocs--;
-                    totalHits++;
-                }
-            }
-        };
     }
 
     @Override
@@ -124,54 +108,104 @@ protected Page getCheckedOutput() throws IOException {
         long start = System.nanoTime();
         try {
             final LuceneScorer scorer = getCurrentOrLoadNextScorer();
-            // no scorer means no more docs
             if (scorer == null) {
                 remainingDocs = 0;
             } else {
-                if (scorer.tags().isEmpty() == false) {
-                    throw new UnsupportedOperationException("tags not supported by " + getClass());
-                }
-                Weight weight = scorer.weight();
-                var leafReaderContext = scorer.leafReaderContext();
-                // see org.apache.lucene.search.TotalHitCountCollector
-                int leafCount = weight.count(leafReaderContext);
-                if (leafCount != -1) {
-                    // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
-                    // handle doc partitioning where the same leaf can be seen multiple times
-                    // since the count is global, consider it only for the first partition and skip the rest
-                    // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
-                    if (scorer.position() == 0) {
-                        // check to not count over the desired number of docs/limit
-                        var count = Math.min(leafCount, remainingDocs);
-                        totalHits += count;
-                        remainingDocs -= count;
-                    }
-                    scorer.markAsDone();
-                } else {
-                    // could not apply shortcut, trigger the search
-                    // TODO: avoid iterating all documents in multiple calls to make cancellation more responsive.
-                    scorer.scoreNextRange(leafCollector, leafReaderContext.reader().getLiveDocs(), remainingDocs);
-                }
+                count(scorer);
+            }
+
+            if (remainingDocs <= 0) {
+                return buildResult();
+            }
+            return null;
+        } finally {
+            processingNanos += System.nanoTime() - start;
+        }
+    }
+
+    private void count(LuceneScorer scorer) throws IOException {
+        PerTagsState state = tagsToState.computeIfAbsent(scorer.tags(), t -> new PerTagsState());
+        Weight weight = scorer.weight();
+        var leafReaderContext = scorer.leafReaderContext();
+        // see org.apache.lucene.search.TotalHitCountCollector
+        int leafCount = weight.count(leafReaderContext);
+        if (leafCount != -1) {
+            // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
+            // handle doc partitioning where the same leaf can be seen multiple times
+            // since the count is global, consider it only for the first partition and skip the rest
+            // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
+            if (scorer.position() == 0) {
+                // check to not count over the desired number of docs/limit
+                var count = Math.min(leafCount, remainingDocs);
+                state.totalHits += count;
+                remainingDocs -= count;
             }
+            scorer.markAsDone();
+        } else {
+            // could not apply shortcut, trigger the search
+            // TODO: avoid iterating all documents in multiple calls to make cancellation more responsive.
+            scorer.scoreNextRange(state, leafReaderContext.reader().getLiveDocs(), remainingDocs);
+        }
+    }
 
-            Page page = null;
-            // emit only one page
-            if (remainingDocs <= 0 && pagesEmitted == 0) {
-                LongBlock count = null;
-                BooleanBlock seen = null;
-                try {
-                    count = blockFactory.newConstantLongBlockWith(totalHits, PAGE_SIZE);
-                    seen = blockFactory.newConstantBooleanBlockWith(true, PAGE_SIZE);
-                    page = new Page(PAGE_SIZE, count, seen);
-                } finally {
-                    if (page == null) {
-                        Releasables.closeExpectNoException(count, seen);
-                    }
+    private Page buildResult() {
+        return switch (tagsToState.size()) {
+            case 0 -> null;
+            case 1 -> {
+                Map.Entry<List<Object>, PerTagsState> e = tagsToState.entrySet().iterator().next();
+                yield buildConstantBlocksResult(e.getKey(), e.getValue());
+            }
+            default -> buildNonConstantBlocksResult();
+        };
+    }
+
+    private Page buildConstantBlocksResult(List<Object> tags, PerTagsState state) {
+        Block[] blocks = new Block[2 + tagTypes.size()];
+        int b = 0;
+        try {
+            blocks[b++] = blockFactory.newConstantLongBlockWith(state.totalHits, 1);
+            blocks[b++] = blockFactory.newConstantBooleanBlockWith(true, 1);
+            for (Object e : tags) {
+                blocks[b++] = BlockUtils.constantBlock(blockFactory, e, 1);
+            }
+            Page page = new Page(1, blocks);
+            blocks = null;
+            return page;
+        } finally {
+            if (blocks != null) {
+                Releasables.closeExpectNoException(blocks);
+            }
+        }
+    }
+
+    private Page buildNonConstantBlocksResult() {
+        BlockUtils.BuilderWrapper[] builders = new BlockUtils.BuilderWrapper[tagTypes.size()];
+        Block[] blocks = new Block[2 + tagTypes.size()];
+        try (LongVector.Builder countBuilder = blockFactory.newLongVectorBuilder(tagsToState.size())) {
+            int b = 0;
+            for (ElementType t : tagTypes) {
+                builders[b++] = BlockUtils.wrapperFor(blockFactory, t, tagsToState.size());
+            }
+
+            for (Map.Entry<List<Object>, PerTagsState> e : tagsToState.entrySet()) {
+                countBuilder.appendLong(e.getValue().totalHits);
+                b = 0;
+                for (Object t : e.getKey()) {
+                    builders[b++].accept(t);
                 }
             }
+
+            blocks[0] = countBuilder.build().asBlock();
+            blocks[1] = blockFactory.newConstantBooleanBlockWith(true, tagsToState.size());
+            for (b = 0; b < builders.length; b++) {
+                blocks[2 + b] = builders[b].builder().build();
+                builders[b] = null;
+            }
+            Page page = new Page(tagsToState.size(), blocks);
+            blocks = null;
             return page;
         } finally {
-            processingNanos += System.nanoTime() - start;
+            Releasables.closeExpectNoException(Releasables.wrap(builders), blocks == null ? () -> {} : Releasables.wrap(blocks));
         }
     }
 
@@ -179,4 +213,28 @@ protected Page getCheckedOutput() throws IOException {
     protected void describe(StringBuilder sb) {
         sb.append(", remainingDocs=").append(remainingDocs);
     }
+
+    private class PerTagsState implements LeafCollector {
+        long totalHits;
+
+        @Override
+        public void setScorer(Scorable scorer) {}
+
+        @Override
+        public void collect(DocIdStream stream) throws IOException {
+            if (remainingDocs > 0) {
+                int count = Math.min(stream.count(), remainingDocs);
+                totalHits += count;
+                remainingDocs -= count;
+            }
+        }
+
+        @Override
+        public void collect(int doc) {
+            if (remainingDocs > 0) {
+                remainingDocs--;
+                totalHits++;
+            }
+        }
+    }
 }
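Note on the new output shape: with tags, the operator emits one row per distinct tag combination, with blocks laid out as `[count, seen, tag0, tag1, ...]`. A minimal wiring sketch, assuming the caller already holds `shardContexts` and a `query` (both placeholders, not part of this change):

```java
import java.util.List;

import org.apache.lucene.search.Query;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.lucene.DataPartitioning;
import org.elasticsearch.compute.lucene.LuceneCountOperator;
import org.elasticsearch.compute.lucene.LuceneOperator;
import org.elasticsearch.compute.lucene.LuceneSliceQueue;
import org.elasticsearch.compute.lucene.ShardContext;

class CountByTagExample {
    // Builds a count source whose single query is tagged with the long 0L,
    // so each emitted page carries a third block holding that tag.
    static LuceneOperator.Factory countByTag(List<? extends ShardContext> shardContexts, Query query) {
        return new LuceneCountOperator.Factory(
            shardContexts,
            ctx -> List.of(new LuceneSliceQueue.QueryAndTags(query, List.of(0L))),
            DataPartitioning.SHARD,
            1, // taskConcurrency
            List.of(ElementType.LONG), // one tag column, matching the List.of(0L) above
            LuceneOperator.NO_LIMIT
        );
    }
}
```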
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
index d8f9787c2864c..56acaa86c299d 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
@@ -31,6 +31,9 @@
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocBlock;
@@ -41,6 +44,7 @@
 import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.lucene.DataPartitioning;
+import org.elasticsearch.compute.lucene.LuceneCountOperator;
 import org.elasticsearch.compute.lucene.LuceneOperator;
 import org.elasticsearch.compute.lucene.LuceneSliceQueue;
 import org.elasticsearch.compute.lucene.LuceneSourceOperator;
@@ -49,9 +53,11 @@
 import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator;
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.HashAggregationOperator;
 import org.elasticsearch.compute.operator.PageConsumerOperator;
 import org.elasticsearch.compute.operator.RowInTableLookupOperator;
 import org.elasticsearch.compute.test.BlockTestUtils;
+import org.elasticsearch.compute.test.CannedSourceOperator;
 import org.elasticsearch.compute.test.OperatorTestCase;
 import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator;
 import org.elasticsearch.compute.test.TestDriverFactory;
@@ -79,8 +85,11 @@
 import java.util.TreeMap;
 
 import static org.elasticsearch.compute.test.OperatorTestCase.randomPageSize;
+import static org.elasticsearch.test.MapMatcher.assertMap;
+import static org.elasticsearch.test.MapMatcher.matchesMap;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
 /**
@@ -360,6 +369,104 @@ public void testHashLookup() {
         }
     }
 
+    public void testPushRoundToCountToQuery() throws IOException {
+        int firstGroupDocs = randomIntBetween(0, 10_000);
+        int secondGroupDocs = randomIntBetween(0, 10_000);
+        int thirdGroupDocs = randomIntBetween(0, 10_000);
+
+        CheckedConsumer<DirectoryReader, IOException> verifier = reader -> {
+            Query firstGroupQuery = LongPoint.newRangeQuery("g", Long.MIN_VALUE, 99);
+            Query secondGroupQuery = LongPoint.newRangeQuery("g", 100, 9999);
+            Query thirdGroupQuery = LongPoint.newRangeQuery("g", 10000, Long.MAX_VALUE);
+
+            LuceneSliceQueue.QueryAndTags firstGroupQueryAndTags = new LuceneSliceQueue.QueryAndTags(firstGroupQuery, List.of(0L));
+            LuceneSliceQueue.QueryAndTags secondGroupQueryAndTags = new LuceneSliceQueue.QueryAndTags(secondGroupQuery, List.of(100L));
+            LuceneSliceQueue.QueryAndTags thirdGroupQueryAndTags = new LuceneSliceQueue.QueryAndTags(thirdGroupQuery, List.of(10000L));
+
+            // Data driver
+            List<Page> dataDriverPages = new ArrayList<>();
+            {
+                LuceneOperator.Factory factory = luceneCountOperatorFactory(
+                    reader,
+                    List.of(ElementType.LONG),
+                    List.of(firstGroupQueryAndTags, secondGroupQueryAndTags, thirdGroupQueryAndTags)
+                );
+                DriverContext driverContext = driverContext();
+                try (
+                    Driver driver = TestDriverFactory.create(
+                        driverContext,
+                        factory.get(driverContext),
+                        List.of(),
+                        new TestResultPageSinkOperator(dataDriverPages::add)
+                    )
+                ) {
+                    OperatorTestCase.runDriver(driver);
+                }
+                assertDriverContext(driverContext);
+            }
+
+            // Reduce driver
+            List<Page> reduceDriverPages = new ArrayList<>();
+            try (CannedSourceOperator sourceOperator = new CannedSourceOperator(dataDriverPages.iterator())) {
+                HashAggregationOperator.HashAggregationOperatorFactory aggFactory =
+                    new HashAggregationOperator.HashAggregationOperatorFactory(
+                        List.of(new BlockHash.GroupSpec(2, ElementType.LONG)),
+                        AggregatorMode.FINAL,
+                        List.of(CountAggregatorFunction.supplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(0, 1))),
+                        Integer.MAX_VALUE,
+                        null
+                    );
+                DriverContext driverContext = driverContext();
+                try (
+                    Driver driver = TestDriverFactory.create(
+                        driverContext,
+                        sourceOperator,
+                        List.of(aggFactory.get(driverContext)),
+                        new TestResultPageSinkOperator(reduceDriverPages::add)
+                    )
+                ) {
+                    OperatorTestCase.runDriver(driver);
+                }
+                assertDriverContext(driverContext);
+            }
+
+            assertThat(reduceDriverPages, hasSize(1));
+            Page result = reduceDriverPages.getFirst();
+            assertThat(result.getBlockCount(), equalTo(2));
+            LongBlock groupsBlock = result.getBlock(0);
+            LongVector groups = groupsBlock.asVector();
+            LongBlock countsBlock = result.getBlock(1);
+            LongVector counts = countsBlock.asVector();
+            Map<Long, Long> actual = new TreeMap<>();
+            for (int p = 0; p < result.getPositionCount(); p++) {
+                actual.put(groups.getLong(p), counts.getLong(p));
+            }
+            assertMap(
+                actual,
+                matchesMap().entry(0L, (long) firstGroupDocs).entry(100L, (long) secondGroupDocs).entry(10000L, (long) thirdGroupDocs)
+            );
+        };
+
+        try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) {
+            for (int i = 0; i < firstGroupDocs; i++) {
+                long g = randomLongBetween(Long.MIN_VALUE, 99);
+                w.addDocument(List.of(new LongField("g", g, Field.Store.NO)));
+            }
+            for (int i = 0; i < secondGroupDocs; i++) {
+                long g = randomLongBetween(100, 9999);
+                w.addDocument(List.of(new LongField("g", g, Field.Store.NO)));
+            }
+            for (int i = 0; i < thirdGroupDocs; i++) {
+                long g = randomLongBetween(10000, Long.MAX_VALUE);
+                w.addDocument(List.of(new LongField("g", g, Field.Store.NO)));
+            }
+
+            try (DirectoryReader reader = w.getReader()) {
+                verifier.accept(reader);
+            }
+        }
+    }
+
     /**
      * Creates a {@link BigArrays} that tracks releases but doesn't throw circuit breaking exceptions.
      */
@@ -394,6 +501,22 @@ static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, List
         );
     }
 
+    static LuceneOperator.Factory luceneCountOperatorFactory(
+        IndexReader reader,
+        List<ElementType> tagTypes,
+        List<LuceneSliceQueue.QueryAndTags> queryAndTags
+    ) {
+        final ShardContext searchContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
+        return new LuceneCountOperator.Factory(
+            List.of(searchContext),
+            ctx -> queryAndTags,
+            randomFrom(DataPartitioning.values()),
+            randomIntBetween(1, 10),
+            tagTypes,
+            LuceneOperator.NO_LIMIT
+        );
+    }
+
     private MappedFieldType.BlockLoaderContext mockBlContext() {
         return new MappedFieldType.BlockLoaderContext() {
             @Override
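The test above pushes the per-tag counts through a final `HashAggregationOperator`. If a consumer only wants to fold the count operator's raw pages into per-tag totals, the block layout is `[count, seen, tag...]`; a sketch of that fold, assuming a single long-typed tag as in the test (the class and method names are illustrative):

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;

class TagTotalsExample {
    // Sums the count block (channel 0) grouped by a single long tag (channel 2).
    // Channel 1 is the always-true "seen" flag and is ignored here.
    static Map<Long, Long> totalsByTag(List<Page> pages) {
        Map<Long, Long> totals = new TreeMap<>();
        for (Page page : pages) {
            LongVector counts = page.<LongBlock>getBlock(0).asVector();
            LongVector tags = page.<LongBlock>getBlock(2).asVector();
            for (int p = 0; p < page.getPositionCount(); p++) {
                totals.merge(tags.getLong(p), counts.getLong(p), Long::sum);
            }
        }
        return totals;
    }
}
```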
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
index de5de4743f971..1d2798ff132cb 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
@@ -7,17 +7,23 @@
 
 package org.elasticsearch.compute.lucene;
 
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
 import org.apache.lucene.document.Document;
-import org.apache.lucene.document.LongPoint;
+import org.apache.lucene.document.SortedNumericDocValuesField;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.search.MatchAllDocsQuery;
-import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.BooleanVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -27,24 +33,179 @@
 import org.elasticsearch.compute.test.TestResultPageSinkOperator;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.indices.CrankyCircuitBreakerService;
+import org.elasticsearch.test.MapMatcher;
 import org.hamcrest.Matcher;
 import org.junit.After;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.function.Function;
 import java.util.function.Supplier;
 
+import static org.elasticsearch.test.MapMatcher.assertMap;
+import static org.elasticsearch.test.MapMatcher.matchesMap;
+import static org.hamcrest.Matchers.both;
+import static org.hamcrest.Matchers.either;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.matchesRegex;
 
 public class LuceneCountOperatorTests extends SourceOperatorTestCase {
+    @ParametersFactory(argumentFormatting = "%s")
+    public static Iterable<Object[]> parameters() {
+        List<Object[]> parameters = new ArrayList<>();
+        for (TestCase c : TestCase.values()) {
+            parameters.add(new Object[] { c });
+        }
+        return parameters;
+    }
+
+    public enum TestCase {
+        MATCH_ALL {
+            @Override
+            List<LuceneSliceQueue.QueryAndTags> queryAndExtra() {
+                return List.of(new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of()));
+            }
+
+            @Override
+            List<ElementType> tagTypes() {
+                return List.of();
+            }
+
+            @Override
+            void checkPages(int numDocs, int limit, List<Page> results) {
+                long count = 0;
+                for (Page p : results) {
+                    assertThat(p.getBlockCount(), equalTo(2));
+                    checkSeen(p, equalTo(1));
+                    count += getCount(p);
+                }
+                if (limit < numDocs) {
+                    assertThat(count, greaterThanOrEqualTo((long) limit));
+                } else {
+                    assertThat(count, equalTo((long) numDocs));
+                }
+            }
+        },
+        MATCH_0 {
+            @Override
+            List<LuceneSliceQueue.QueryAndTags> queryAndExtra() {
+                return List.of(new LuceneSliceQueue.QueryAndTags(SortedNumericDocValuesField.newSlowExactQuery("s", 0), List.of()));
+            }
+
+            @Override
+            List<ElementType> tagTypes() {
+                return List.of();
+            }
+
+            @Override
+            void checkPages(int numDocs, int limit, List<Page> results) {
+                long count = 0;
+                for (Page p : results) {
+                    assertThat(p.getBlockCount(), equalTo(2));
+                    checkSeen(p, equalTo(1));
+                    count += getCount(p);
+                }
+                assertThat(count, equalTo((long) Math.min(numDocs, 1)));
+            }
+        },
+        MATCH_0_AND_1 {
+            @Override
+            List<LuceneSliceQueue.QueryAndTags> queryAndExtra() {
+                return List.of(
+                    new LuceneSliceQueue.QueryAndTags(SortedNumericDocValuesField.newSlowExactQuery("s", 0), List.of(123)),
+                    new LuceneSliceQueue.QueryAndTags(SortedNumericDocValuesField.newSlowExactQuery("s", 1), List.of(456))
+                );
+            }
+
+            @Override
+            List<ElementType> tagTypes() {
+                return List.of(ElementType.INT);
+            }
+
+            @Override
+            void checkPages(int numDocs, int limit, List<Page> results) {
+                Map<Integer, Long> counts = getCountsByTag(results);
+                MapMatcher matcher = matchesMap();
+                if (numDocs > 0) {
+                    matcher = matcher.entry(123, 1L);
+                }
+                if (numDocs > 1) {
+                    matcher = matcher.entry(456, 1L);
+                }
+                assertMap(counts, matcher);
+            }
+        },
+        LTE_100_GT_100 {
+            @Override
+            List<LuceneSliceQueue.QueryAndTags> queryAndExtra() {
+                return List.of(
+                    new LuceneSliceQueue.QueryAndTags(SortedNumericDocValuesField.newSlowRangeQuery("s", 0, 100), List.of(123)),
+                    new LuceneSliceQueue.QueryAndTags(SortedNumericDocValuesField.newSlowRangeQuery("s", 101, Long.MAX_VALUE), List.of(456))
+                );
+            }
+
+            @Override
+            List<ElementType> tagTypes() {
+                return List.of(ElementType.INT);
+            }
+
+            @Override
+            void checkPages(int numDocs, int limit, List<Page> results) {
+                Map<Integer, Long> counts = getCountsByTag(results);
+                MapMatcher matcher = matchesMap();
+                if (limit >= numDocs) {
+                    // The normal case - we don't abort early.
+                    if (numDocs > 0) {
+                        matcher = matcher.entry(123, (long) Math.min(numDocs, 101));
+                    }
+                    if (numDocs > 101) {
+                        matcher = matcher.entry(456, (long) numDocs - 101);
+                    }
+                    assertMap(counts, matcher);
+                    return;
+                }
+                /*
+                 * The abnormal case - we abort the counting early. But this is best-effort
+                 * so we *might* have a complete count. Otherwise, we'll have lower counts.
+                 */
+                if (counts.containsKey(123)) {
+                    matcher = matcher.entry(123, both(greaterThan(0L)).and(lessThanOrEqualTo((long) Math.min(numDocs, 101))));
+                }
+                if (counts.containsKey(456)) {
+                    matcher = matcher.entry(456, both(greaterThan(0L)).and(lessThanOrEqualTo((long) numDocs - 101)));
+                }
+                assertThat(counts.keySet(), hasSize(either(equalTo(1)).or(equalTo(2))));
+                assertMap(counts, matcher);
+            }
+        };
+
+        abstract List<LuceneSliceQueue.QueryAndTags> queryAndExtra();
+
+        abstract List<ElementType> tagTypes();
+
+        abstract void checkPages(int numDocs, int limit, List<Page> results);
+
+        // TODO check for the count of count shortcuts taken
+    }
+
+    private final TestCase testCase;
+
     private Directory directory = newDirectory();
     private IndexReader reader;
 
+    public LuceneCountOperatorTests(TestCase testCase) {
+        this.testCase = testCase;
+    }
+
     @After
     public void closeIndex() throws IOException {
         IOUtils.close(reader, directory);
@@ -56,7 +217,6 @@ protected LuceneCountOperator.Factory simple(SimpleOptions options) {
     }
 
     private LuceneCountOperator.Factory simple(DataPartitioning dataPartitioning, int numDocs, int limit) {
-        boolean enableShortcut = randomBoolean();
         int commitEvery = Math.max(1, numDocs / 10);
         try (
             RandomIndexWriter writer = new RandomIndexWriter(
@@ -67,13 +227,8 @@ private LuceneCountOperator.Factory simple(DataPartitioning dataPartitioning, in
         ) {
             for (int d = 0; d < numDocs; d++) {
                 var doc = new Document();
-                doc.add(new LongPoint("s", d));
+                doc.add(new SortedNumericDocValuesField("s", d));
                 writer.addDocument(doc);
-                if (enableShortcut == false && randomBoolean()) {
-                    doc = new Document();
-                    doc.add(new LongPoint("s", randomLongBetween(numDocs * 5L, numDocs * 10L)));
-                    writer.addDocument(doc);
-                }
                 if (d % commitEvery == 0) {
                     writer.commit();
                 }
@@ -84,19 +239,8 @@ private LuceneCountOperator.Factory simple(DataPartitioning dataPartitioning, in
         }
 
         ShardContext ctx = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
-        final Query query;
-        if (enableShortcut && randomBoolean()) {
-            query = new MatchAllDocsQuery();
-        } else {
-            query = LongPoint.newRangeQuery("s", 0, numDocs);
-        }
-        return new LuceneCountOperator.Factory(
-            List.of(ctx),
-            c -> List.of(new LuceneSliceQueue.QueryAndTags(query, List.of())),
-            dataPartitioning,
-            between(1, 8),
-            limit
-        );
+        Function<ShardContext, List<LuceneSliceQueue.QueryAndTags>> queryFunction = c -> testCase.queryAndExtra();
+        return new LuceneCountOperator.Factory(List.of(ctx), queryFunction, dataPartitioning, between(1, 8), testCase.tagTypes(), limit);
     }
 
     @Override
@@ -162,21 +306,39 @@ private void testCount(Supplier<DriverContext> contexts, int size, int limit) {
         }
         OperatorTestCase.runDriver(drivers);
         assertThat(results.size(), lessThanOrEqualTo(taskConcurrency));
-        long totalCount = 0;
+        testCase.checkPages(size, limit, results);
+    }
+
+    private static long getCount(Page p) {
+        LongBlock b = p.getBlock(0);
+        LongVector v = b.asVector();
+        assertThat(v.getPositionCount(), equalTo(1));
+        assertThat(v.isConstant(), equalTo(true));
+        return v.getLong(0);
+    }
+
+    private static void checkSeen(Page p, Matcher<Integer> positionCount) {
+        BooleanBlock b = p.getBlock(1);
+        BooleanVector v = b.asVector();
+        assertThat(v.getPositionCount(), positionCount);
+        assertThat(v.isConstant(), equalTo(true));
+        assertThat(v.getBoolean(0), equalTo(true));
+    }
+
+    private static Map<Integer, Long> getCountsByTag(List<Page> results) {
+        Map<Integer, Long> totals = new TreeMap<>();
         for (Page page : results) {
-            assertThat(page.getPositionCount(), is(1));
-            assertThat(page.getBlockCount(), is(2));
-            LongBlock lb = page.getBlock(0);
-            assertThat(lb.getPositionCount(), is(1));
-            long count = lb.getLong(0);
-            assertThat(count, lessThanOrEqualTo((long) limit));
-            totalCount += count;
-            BooleanBlock bb = page.getBlock(1);
-            assertTrue(bb.getBoolean(0));
-        }
-        // We can't verify the limit
-        if (size <= limit) {
-            assertThat(totalCount, equalTo((long) size));
+            assertThat(page.getBlockCount(), equalTo(3));
+            checkSeen(page, greaterThanOrEqualTo(0));
+            LongBlock countsBlock = page.getBlock(0);
+            LongVector counts = countsBlock.asVector();
+            IntBlock groupsBlock = page.getBlock(2);
+            IntVector groups = groupsBlock.asVector();
+            for (int p = 0; p < page.getPositionCount(); p++) {
+                long count = counts.getLong(p);
+                totals.compute(groups.getInt(p), (k, prev) -> prev == null ? count : (prev + count));
+            }
         }
+        return totals;
     }
 }
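The expectations in `LTE_100_GT_100` follow from how `simple(...)` indexes the data: document `d` gets `s = d`, so `s <= 100` matches `min(numDocs, 101)` documents and `s >= 101` matches the remainder. A quick standalone check of that arithmetic (plain Java, illustrative only):

```java
class Lte100Gt100Expectation {
    // For docs with s = 0..numDocs-1: [0,100] matches min(numDocs, 101) docs,
    // [101,MAX] matches the rest; together they cover every document exactly once.
    static void check(int numDocs) {
        long lte100 = Math.min(numDocs, 101);
        long gt100 = Math.max(numDocs - 101, 0);
        if (lte100 + gt100 != numDocs) {
            throw new AssertionError("ranges must partition the index");
        }
    }
}
```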
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
index ca680276c8d8c..3249031824210 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
@@ -98,6 +98,22 @@ int numResults(int numDocs) {
                 return numDocs;
             }
         },
+        MATCH_0 {
+            @Override
+            List<LuceneSliceQueue.QueryAndTags> queryAndExtra() {
+                return List.of(new LuceneSliceQueue.QueryAndTags(SortedNumericDocValuesField.newSlowExactQuery("s", 0), List.of()));
+            }
+
+            @Override
+            void checkPages(int numDocs, int limit, int maxPageSize, List<Page> results) {
+                assertThat(results, hasSize(both(greaterThanOrEqualTo(0)).and(lessThanOrEqualTo(1))));
+            }
+
+            @Override
+            int numResults(int numDocs) {
+                return Math.min(numDocs, 1);
+            }
+        },
         MATCH_0_AND_1 {
             @Override
             List<LuceneSliceQueue.QueryAndTags> queryAndExtra() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
index 9520f68580731..093dbca0ae51f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
@@ -400,6 +400,7 @@ public LuceneCountOperator.Factory countSource(LocalExecutionPlannerContext cont
             querySupplier(queryBuilder),
             context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()),
             context.queryPragmas().taskConcurrency(),
+            List.of(),
             limit == null ? NO_LIMIT : (Integer) limit.fold(context.foldCtx())
         );
     }
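On the planner side, `countSource` passes `List.of()` for the tags, so production plans still see a single-row `[count, seen]` page built through the constant fast path (`buildConstantBlocksResult`). A sketch of the page that path would produce for a hypothetical single tag combination `[0L]` with 42 total hits (all values illustrative, assuming a `blockFactory` is at hand):

```java
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.Page;

class ConstantCountPageExample {
    // Equivalent of buildConstantBlocksResult for one tag combination [0L]:
    // every block has a single position, so constant blocks are the cheapest fit.
    static Page constantCountPage(BlockFactory blockFactory) {
        Block[] blocks = new Block[] {
            blockFactory.newConstantLongBlockWith(42, 1),      // count
            blockFactory.newConstantBooleanBlockWith(true, 1), // seen
            BlockUtils.constantBlock(blockFactory, 0L, 1)      // tag0
        };
        return new Page(1, blocks);
    }
}
```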