diff --git a/docs/changelog/123296.yaml b/docs/changelog/123296.yaml
new file mode 100644
index 0000000000000..1dd32d21294fb
--- /dev/null
+++ b/docs/changelog/123296.yaml
@@ -0,0 +1,5 @@
+pr: 123296
+summary: Avoid over collecting in Limit or Lucene Operator
+area: ES|QL
+type: bug
+issues: []
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java
index 61a7cbad3e8af..c570d05e9b2ae 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java
@@ -16,10 +16,10 @@
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.DoubleVector;
-import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.Limiter;
 import org.elasticsearch.compute.operator.SourceOperator;
 import org.elasticsearch.core.Releasables;
 
@@ -37,6 +37,7 @@ public class LuceneSourceOperator extends LuceneOperator {
 
     private int currentPagePos = 0;
     private int remainingDocs;
+    private final Limiter limiter;
 
     private IntVector.Builder docsBuilder;
     private DoubleVector.Builder scoreBuilder;
@@ -46,6 +47,7 @@ public class LuceneSourceOperator extends LuceneOperator {
 
     public static class Factory extends LuceneOperator.Factory {
         private final int maxPageSize;
+        private final Limiter limiter;
 
         public Factory(
             List<? extends ShardContext> contexts,
@@ -58,11 +60,13 @@ public Factory(
         ) {
             super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
             this.maxPageSize = maxPageSize;
+            // TODO: use a single limiter for multiple stage execution
+            this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
         }
 
         @Override
         public SourceOperator get(DriverContext driverContext) {
-            return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, scoreMode);
+            return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
         }
 
         public int maxPageSize() {
@@ -84,10 +88,18 @@ public String describe() {
     }
 
     @SuppressWarnings("this-escape")
-    public LuceneSourceOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, int limit, ScoreMode scoreMode) {
+    public LuceneSourceOperator(
+        BlockFactory blockFactory,
+        int maxPageSize,
+        LuceneSliceQueue sliceQueue,
+        int limit,
+        Limiter limiter,
+        ScoreMode scoreMode
+    ) {
         super(blockFactory, maxPageSize, sliceQueue);
         this.minPageSize = Math.max(1, maxPageSize / 2);
         this.remainingDocs = limit;
+        this.limiter = limiter;
         int estimatedSize = Math.min(limit, maxPageSize);
         boolean success = false;
         try {
@@ -140,7 +152,7 @@ public void collect(int doc) throws IOException {
 
     @Override
     public boolean isFinished() {
-        return doneCollecting || remainingDocs <= 0;
+        return doneCollecting || limiter.remaining() == 0;
     }
 
     @Override
@@ -160,6 +172,7 @@ public Page getCheckedOutput() throws IOException {
             if (scorer == null) {
                 return null;
             }
+            final int remainingDocsStart = remainingDocs = limiter.remaining();
             try {
                 scorer.scoreNextRange(
                     leafCollector,
@@ -171,28 +184,32 @@
                 );
             } catch (CollectionTerminatedException ex) {
                 // The leaf collector terminated the execution
+                doneCollecting = true;
                 scorer.markAsDone();
             }
+            final int collectedDocs = remainingDocsStart - remainingDocs;
+            final int discardedDocs = collectedDocs - limiter.tryAccumulateHits(collectedDocs);
             Page page = null;
-            if (currentPagePos >= minPageSize || remainingDocs <= 0 || scorer.isDone()) {
-                IntBlock shard = null;
-                IntBlock leaf = null;
+            if (currentPagePos >= minPageSize || scorer.isDone() || (remainingDocs = limiter.remaining()) == 0) {
+                IntVector shard = null;
+                IntVector leaf = null;
                 IntVector docs = null;
                 DoubleVector scores = null;
                 DocBlock docBlock = null;
+                currentPagePos -= discardedDocs;
                 try {
-                    shard = blockFactory.newConstantIntBlockWith(scorer.shardContext().index(), currentPagePos);
-                    leaf = blockFactory.newConstantIntBlockWith(scorer.leafReaderContext().ord, currentPagePos);
-                    docs = docsBuilder.build();
+                    shard = blockFactory.newConstantIntVector(scorer.shardContext().index(), currentPagePos);
+                    leaf = blockFactory.newConstantIntVector(scorer.leafReaderContext().ord, currentPagePos);
+                    docs = buildDocsVector(currentPagePos);
                     docsBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize));
-                    docBlock = new DocVector(shard.asVector(), leaf.asVector(), docs, true).asBlock();
+                    docBlock = new DocVector(shard, leaf, docs, true).asBlock();
                     shard = null;
                     leaf = null;
                     docs = null;
                     if (scoreBuilder == null) {
                         page = new Page(currentPagePos, docBlock);
                     } else {
-                        scores = scoreBuilder.build();
+                        scores = buildScoresVector(currentPagePos);
                         scoreBuilder = blockFactory.newDoubleVectorBuilder(Math.min(remainingDocs, maxPageSize));
                         page = new Page(currentPagePos, docBlock, scores.asBlock());
                     }
@@ -209,6 +226,36 @@
         }
     }
 
+    private IntVector buildDocsVector(int upToPositions) {
+        final IntVector docs = docsBuilder.build();
+        assert docs.getPositionCount() >= upToPositions : docs.getPositionCount() + " < " + upToPositions;
+        if (docs.getPositionCount() == upToPositions) {
+            return docs;
+        }
+        try (var slice = blockFactory.newIntVectorFixedBuilder(upToPositions)) {
+            for (int i = 0; i < upToPositions; i++) {
+                slice.appendInt(docs.getInt(i));
+            }
+            docs.close();
+            return slice.build();
+        }
+    }
+
+    private DoubleVector buildScoresVector(int upToPositions) {
+        final DoubleVector scores = scoreBuilder.build();
+        assert scores.getPositionCount() >= upToPositions : scores.getPositionCount() + " < " + upToPositions;
+        if (scores.getPositionCount() == upToPositions) {
+            return scores;
+        }
+        try (var slice = blockFactory.newDoubleVectorBuilder(upToPositions)) {
+            for (int i = 0; i < upToPositions; i++) {
+                slice.appendDouble(scores.getDouble(i));
+            }
+            scores.close();
+            return slice.build();
+        }
+    }
+
     @Override
     public void close() {
         Releasables.close(docsBuilder, scoreBuilder);
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java
index b669be9192d06..3ef9c420f59ff 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java
@@ -22,15 +22,6 @@
 import java.util.Objects;
 
 public class LimitOperator implements Operator {
-    /**
-     * Total number of position that are emitted by this operator.
-     */
-    private final int limit;
-
-    /**
-     * Remaining number of positions that will be emitted by this operator.
-     */
-    private int limitRemaining;
 
     /**
      * Count of pages that have been processed by this operator.
@@ -49,35 +40,49 @@
     private Page lastInput;
 
+    private final Limiter limiter;
     private boolean finished;
 
-    public LimitOperator(int limit) {
-        this.limit = this.limitRemaining = limit;
+    public LimitOperator(Limiter limiter) {
+        this.limiter = limiter;
     }
 
-    public record Factory(int limit) implements OperatorFactory {
+    public static final class Factory implements OperatorFactory {
+        private final Limiter limiter;
+
+        public Factory(int limit) {
+            this.limiter = new Limiter(limit);
+        }
 
         @Override
         public LimitOperator get(DriverContext driverContext) {
-            return new LimitOperator(limit);
+            return new LimitOperator(limiter);
         }
 
         @Override
         public String describe() {
-            return "LimitOperator[limit = " + limit + "]";
+            return "LimitOperator[limit = " + limiter.limit() + "]";
         }
     }
 
     @Override
     public boolean needsInput() {
-        return finished == false && lastInput == null;
+        return finished == false && lastInput == null && limiter.remaining() > 0;
     }
 
     @Override
     public void addInput(Page page) {
         assert lastInput == null : "has pending input page";
-        lastInput = page;
-        rowsReceived += page.getPositionCount();
+        final int acceptedRows = limiter.tryAccumulateHits(page.getPositionCount());
+        if (acceptedRows == 0) {
+            page.releaseBlocks();
+            assert isFinished();
+        } else if (acceptedRows < page.getPositionCount()) {
+            lastInput = truncatePage(page, acceptedRows);
+        } else {
+            lastInput = page;
+        }
+        rowsReceived += acceptedRows;
     }
 
     @Override
@@ -87,7 +92,7 @@ public void finish() {
 
     @Override
     public boolean isFinished() {
-        return finished && lastInput == null;
+        return lastInput == null && (finished || limiter.remaining() == 0);
     }
 
     @Override
@@ -95,47 +100,38 @@ public Page getOutput() {
         if (lastInput == null) {
             return null;
         }
-
-        Page result;
-        if (lastInput.getPositionCount() <= limitRemaining) {
-            result = lastInput;
-            limitRemaining -= lastInput.getPositionCount();
-        } else {
-            int[] filter = new int[limitRemaining];
-            for (int i = 0; i < limitRemaining; i++) {
-                filter[i] = i;
-            }
-            Block[] blocks = new Block[lastInput.getBlockCount()];
-            boolean success = false;
-            try {
-                for (int b = 0; b < blocks.length; b++) {
-                    blocks[b] = lastInput.getBlock(b).filter(filter);
-                }
-                success = true;
-            } finally {
-                if (success == false) {
-                    Releasables.closeExpectNoException(lastInput::releaseBlocks, Releasables.wrap(blocks));
-                } else {
-                    lastInput.releaseBlocks();
-                }
-                lastInput = null;
-            }
-            result = new Page(blocks);
-            limitRemaining = 0;
-        }
-        if (limitRemaining == 0) {
-            finished = true;
-        }
+        final Page result = lastInput;
         lastInput = null;
         pagesProcessed++;
         rowsEmitted += result.getPositionCount();
+        return result;
+    }
+
+    private static Page truncatePage(Page page, int upTo) {
+        int[] filter = new int[upTo];
+        for (int i = 0; i < upTo; i++) {
+            filter[i] = i;
+        }
+        final Block[] blocks = new Block[page.getBlockCount()];
+        Page result = null;
+        try {
+            for (int b = 0; b < blocks.length; b++) {
+                blocks[b] = page.getBlock(b).filter(filter);
+            }
+            result = new Page(blocks);
+        } finally {
+            if (result == null) {
+                Releasables.closeExpectNoException(page::releaseBlocks, Releasables.wrap(blocks));
+            } else {
+                page.releaseBlocks();
+            }
+        }
         return result;
     }
 
     @Override
     public Status status() {
-        return new Status(limit, limitRemaining, pagesProcessed, rowsReceived, rowsEmitted);
+        return new Status(limiter.limit(), limiter.remaining(), pagesProcessed, rowsReceived, rowsEmitted);
     }
 
     @Override
@@ -147,6 +143,8 @@ public void close() {
 
     @Override
     public String toString() {
+        final int limitRemaining = limiter.remaining();
+        final int limit = limiter.limit();
         return "LimitOperator[limit = " + limitRemaining + "/" + limit + "]";
     }
 
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Limiter.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Limiter.java
new file mode 100644
index 0000000000000..a74a93eceec40
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Limiter.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * A shared limiter used by multiple drivers to collect hits in parallel without exceeding the output limit.
+ * For example, if the query `FROM test-1,test-2 | LIMIT 100` is run with two drivers, and one driver (e.g., querying `test-1`)
+ * has collected 60 hits, then the other driver querying `test-2` should collect at most 40 hits.
+ */
+public class Limiter {
+    private final int limit;
+    private final AtomicInteger collected = new AtomicInteger();
+
+    public static Limiter NO_LIMIT = new Limiter(Integer.MAX_VALUE) {
+        @Override
+        public int tryAccumulateHits(int numHits) {
+            return numHits;
+        }
+
+        @Override
+        public int remaining() {
+            return Integer.MAX_VALUE;
+        }
+    };
+
+    public Limiter(int limit) {
+        this.limit = limit;
+    }
+
+    /**
+     * Returns the remaining number of hits that can be collected.
+     */
+    public int remaining() {
+        final int remaining = limit - collected.get();
+        assert remaining >= 0 : remaining;
+        return remaining;
+    }
+
+    /**
+     * Returns the limit of this limiter.
+     */
+    public int limit() {
+        return limit;
+    }
+
+    /**
+     * Tries to accumulate hits and returns the number of hits that have been accepted.
+     *
+     * @param numHits the number of hits to try to accumulate
+     * @return the accepted number of hits. If the returned number is less than numHits,
+     * it means the limit has been reached and the difference can be discarded.
+     */
+    public int tryAccumulateHits(int numHits) {
+        while (true) {
+            int curVal = collected.get();
+            if (curVal >= limit) {
+                return 0;
+            }
+            final int toAccept = Math.min(limit - curVal, numHits);
+            if (collected.compareAndSet(curVal, curVal + toAccept)) {
+                return toAccept;
+            }
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
index 42c9f49a2db7c..4a6368c5dee27 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
@@ -25,6 +25,8 @@
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.compute.operator.SinkOperator;
 import org.elasticsearch.compute.operator.SourceOperator;
 import org.elasticsearch.compute.test.AnyOperatorTestCase;
 import org.elasticsearch.compute.test.OperatorTestCase;
@@ -44,6 +46,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 
 import static org.hamcrest.Matchers.both;
@@ -94,7 +97,16 @@ private LuceneSourceOperator.Factory simple(DataPartitioning dataPartitioning, i
         ShardContext ctx = new MockShardContext(reader, 0);
         Function<ShardContext, Query> queryFunction = c -> new MatchAllDocsQuery();
         int maxPageSize = between(10, Math.max(10, numDocs));
-        return new LuceneSourceOperator.Factory(List.of(ctx), queryFunction, dataPartitioning, 1, maxPageSize, limit, scoring);
+        int taskConcurrency = randomIntBetween(1, 4);
+        return new LuceneSourceOperator.Factory(
+            List.of(ctx),
+            queryFunction,
+            dataPartitioning,
+            taskConcurrency,
+            maxPageSize,
+            limit,
+            scoring
+        );
     }
 
     @Override
@@ -120,23 +132,23 @@ public void testShardDataPartitioning() {
 
     public void testEarlyTermination() {
         int size = between(1_000, 20_000);
-        int limit = between(10, size);
+        int limit = between(0, Integer.MAX_VALUE);
         LuceneSourceOperator.Factory factory = simple(randomFrom(DataPartitioning.values()), size, limit, scoring);
-        try (SourceOperator sourceOperator = factory.get(driverContext())) {
-            assertFalse(sourceOperator.isFinished());
-            int collected = 0;
-            while (sourceOperator.isFinished() == false) {
-                Page page = sourceOperator.getOutput();
-                if (page != null) {
-                    collected += page.getPositionCount();
-                    page.releaseBlocks();
-                }
-                if (collected >= limit) {
-                    assertTrue("source operator is not finished after reaching limit", sourceOperator.isFinished());
-                    assertThat(collected, equalTo(limit));
-                }
-            }
+        int taskConcurrency = factory.taskConcurrency();
+        final AtomicInteger receivedRows = new AtomicInteger();
+        List<Driver> drivers = new ArrayList<>();
+        for (int i = 0; i < taskConcurrency; i++) {
+            DriverContext driverContext = driverContext();
+            SourceOperator sourceOperator = factory.get(driverContext);
+            SinkOperator sinkOperator = new PageConsumerOperator(p -> {
+                receivedRows.addAndGet(p.getPositionCount());
+                p.releaseBlocks();
+            });
+            Driver driver = new Driver("driver" + i, driverContext, sourceOperator, List.of(), sinkOperator, () -> {});
+            drivers.add(driver);
         }
+        OperatorTestCase.runDriver(drivers);
+        assertThat(receivedRows.get(), equalTo(Math.min(limit, size)));
     }
 
     public void testEmpty() {
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java
index acc62de0884c2..b5fbba1b84b09 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java
@@ -141,7 +141,7 @@ public void doClose() {
         if (randomBoolean()) {
             int limit = between(0, ids.size());
             it = ids.subList(0, limit).iterator();
-            intermediateOperators.add(new LimitOperator(limit));
+            intermediateOperators.add(new LimitOperator(new Limiter(limit)));
         } else {
             it = ids.iterator();
         }
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java
index b05be86a164aa..9b3d83ce3f387 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java
@@ -15,7 +15,9 @@
 import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator;
 import org.hamcrest.Matcher;
 
+import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.LongStream;
 
 import static org.elasticsearch.compute.test.RandomBlock.randomElementType;
@@ -126,6 +128,49 @@ public void testBlockPreciselyRemaining() {
         }
     }
 
+    public void testEarlyTermination() {
+        int numDrivers = between(1, 4);
+        final List<Driver> drivers = new ArrayList<>();
+        final int limit = between(1, 10_000);
+        final LimitOperator.Factory limitFactory = new LimitOperator.Factory(limit);
+        final AtomicInteger receivedRows = new AtomicInteger();
+        for (int i = 0; i < numDrivers; i++) {
+            DriverContext driverContext = driverContext();
+            SourceOperator sourceOperator = new SourceOperator() {
+                boolean finished = false;
+
+                @Override
+                public void finish() {
+                    finished = true;
+                }
+
+                @Override
+                public boolean isFinished() {
+                    return finished;
+                }
+
+                @Override
+                public Page getOutput() {
+                    return new Page(randomBlock(driverContext.blockFactory(), between(1, 100)));
+                }
+
+                @Override
+                public void close() {
+
+                }
+            };
+            SinkOperator sinkOperator = new PageConsumerOperator(p -> {
+                receivedRows.addAndGet(p.getPositionCount());
+                p.releaseBlocks();
+            });
+            drivers.add(
+                new Driver("driver" + i, driverContext, sourceOperator, List.of(limitFactory.get(driverContext)), sinkOperator, () -> {})
+            );
+        }
+        runDriver(drivers);
+        assertThat(receivedRows.get(), equalTo(limit));
+    }
+
     Block randomBlock(BlockFactory blockFactory, int size) {
         if (randomBoolean()) {
            return blockFactory.newConstantNullBlock(size);
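
A minimal, hypothetical sketch (not part of the patch) of how the shared Limiter introduced above behaves: several "drivers", modeled here as plain threads, race to accumulate hits, and the CAS loop in tryAccumulateHits guarantees the accepted total never exceeds the limit. The class name LimiterUsageSketch is invented for illustration; the only assumption is that the org.elasticsearch.compute.operator.Limiter class added by this change is on the classpath.

import org.elasticsearch.compute.operator.Limiter;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;

public class LimiterUsageSketch {
    public static void main(String[] args) throws InterruptedException {
        final Limiter limiter = new Limiter(100);           // e.g. `FROM test-1,test-2 | LIMIT 100`
        final AtomicInteger accepted = new AtomicInteger();
        final List<Thread> drivers = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            Thread driver = new Thread(() -> {
                // Keep "collecting" pages until the shared limit is exhausted.
                while (limiter.remaining() > 0) {
                    int pageRows = ThreadLocalRandom.current().nextInt(1, 32);
                    int taken = limiter.tryAccumulateHits(pageRows);
                    // Rows beyond `taken` would be discarded (page truncated), mirroring
                    // LimitOperator#addInput and LuceneSourceOperator#getCheckedOutput.
                    accepted.addAndGet(taken);
                }
            });
            drivers.add(driver);
            driver.start();
        }
        for (Thread driver : drivers) {
            driver.join();
        }
        System.out.println("accepted = " + accepted.get()); // always exactly 100
    }
}

Because tryAccumulateHits only ever hands out what is still available, each operator can drop its local overflow (truncatePage in LimitOperator, buildDocsVector and buildScoresVector in LuceneSourceOperator) without any further coordination between drivers.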