docs/changelog/123296.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 123296
summary: Avoid over collecting in Limit or Lucene Operator
area: ES|QL
type: bug
issues: []
LuceneSourceOperator.java
@@ -16,10 +16,10 @@
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Limiter;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.core.Releasables;

@@ -37,6 +37,7 @@ public class LuceneSourceOperator extends LuceneOperator {

private int currentPagePos = 0;
private int remainingDocs;
private final Limiter limiter;

private IntVector.Builder docsBuilder;
private DoubleVector.Builder scoreBuilder;
@@ -46,6 +47,7 @@ public class LuceneSourceOperator extends LuceneOperator {
public static class Factory extends LuceneOperator.Factory {

private final int maxPageSize;
private final Limiter limiter;

public Factory(
List<? extends ShardContext> contexts,
@@ -58,11 +60,13 @@ public Factory(
) {
super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
this.maxPageSize = maxPageSize;
// TODO: use a single limiter for multiple stage execution
this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
}

@Override
public SourceOperator get(DriverContext driverContext) {
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, scoreMode);
return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
}

public int maxPageSize() {
@@ -84,10 +88,18 @@ public String describe() {
}

@SuppressWarnings("this-escape")
public LuceneSourceOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, int limit, ScoreMode scoreMode) {
public LuceneSourceOperator(
BlockFactory blockFactory,
int maxPageSize,
LuceneSliceQueue sliceQueue,
int limit,
Limiter limiter,
ScoreMode scoreMode
) {
super(blockFactory, maxPageSize, sliceQueue);
this.minPageSize = Math.max(1, maxPageSize / 2);
this.remainingDocs = limit;
this.limiter = limiter;
int estimatedSize = Math.min(limit, maxPageSize);
boolean success = false;
try {
@@ -140,7 +152,7 @@ public void collect(int doc) throws IOException {

@Override
public boolean isFinished() {
return doneCollecting || remainingDocs <= 0;
return doneCollecting || limiter.remaining() == 0;
}

@Override
@@ -160,6 +172,7 @@ public Page getCheckedOutput() throws IOException {
if (scorer == null) {
return null;
}
final int remainingDocsStart = remainingDocs = limiter.remaining();
try {
scorer.scoreNextRange(
leafCollector,
@@ -171,28 +184,32 @@ public Page getCheckedOutput() throws IOException {
);
} catch (CollectionTerminatedException ex) {
// The leaf collector terminated the execution
doneCollecting = true;
scorer.markAsDone();
}
final int collectedDocs = remainingDocsStart - remainingDocs;
final int discardedDocs = collectedDocs - limiter.tryAccumulateHits(collectedDocs);
Page page = null;
if (currentPagePos >= minPageSize || remainingDocs <= 0 || scorer.isDone()) {
IntBlock shard = null;
IntBlock leaf = null;
if (currentPagePos >= minPageSize || scorer.isDone() || (remainingDocs = limiter.remaining()) == 0) {
IntVector shard = null;
IntVector leaf = null;
IntVector docs = null;
DoubleVector scores = null;
DocBlock docBlock = null;
currentPagePos -= discardedDocs;
try {
shard = blockFactory.newConstantIntBlockWith(scorer.shardContext().index(), currentPagePos);
leaf = blockFactory.newConstantIntBlockWith(scorer.leafReaderContext().ord, currentPagePos);
docs = docsBuilder.build();
shard = blockFactory.newConstantIntVector(scorer.shardContext().index(), currentPagePos);
leaf = blockFactory.newConstantIntVector(scorer.leafReaderContext().ord, currentPagePos);
docs = buildDocsVector(currentPagePos);
docsBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize));
docBlock = new DocVector(shard.asVector(), leaf.asVector(), docs, true).asBlock();
docBlock = new DocVector(shard, leaf, docs, true).asBlock();
shard = null;
leaf = null;
docs = null;
if (scoreBuilder == null) {
page = new Page(currentPagePos, docBlock);
} else {
scores = scoreBuilder.build();
scores = buildScoresVector(currentPagePos);
scoreBuilder = blockFactory.newDoubleVectorBuilder(Math.min(remainingDocs, maxPageSize));
page = new Page(currentPagePos, docBlock, scores.asBlock());
}
@@ -209,6 +226,36 @@ public Page getCheckedOutput() throws IOException {
}
}

private IntVector buildDocsVector(int upToPositions) {
final IntVector docs = docsBuilder.build();
assert docs.getPositionCount() >= upToPositions : docs.getPositionCount() + " < " + upToPositions;
if (docs.getPositionCount() == upToPositions) {
return docs;
}
try (var slice = blockFactory.newIntVectorFixedBuilder(upToPositions)) {
for (int i = 0; i < upToPositions; i++) {
slice.appendInt(docs.getInt(i));
}
docs.close();
return slice.build();
}
}

private DoubleVector buildScoresVector(int upToPositions) {
final DoubleVector scores = scoreBuilder.build();
assert scores.getPositionCount() >= upToPositions : scores.getPositionCount() + " < " + upToPositions;
if (scores.getPositionCount() == upToPositions) {
return scores;
}
try (var slice = blockFactory.newDoubleVectorBuilder(upToPositions)) {
for (int i = 0; i < upToPositions; i++) {
slice.appendDouble(scores.getDouble(i));
}
scores.close();
return slice.build();
}
}

@Override
public void close() {
Releasables.close(docsBuilder, scoreBuilder);
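To make the hit accounting in getCheckedOutput concrete, here is a small sketch of the same bookkeeping against the shared Limiter introduced by this change (defined below in Limiter.java). The limit, batch sizes, and the simulated sibling driver are made-up illustrations, not part of the diff:

import org.elasticsearch.compute.operator.Limiter;

// Illustrative only: mirrors the collected/discarded accounting in LuceneSourceOperator#getCheckedOutput.
public class HitAccountingSketch {
    public static void main(String[] args) {
        Limiter limiter = new Limiter(100);              // query-wide LIMIT 100, shared by all drivers
        limiter.tryAccumulateHits(70);                   // a sibling driver has already claimed 70 hits

        int remainingDocsStart = limiter.remaining();    // 30: snapshot taken before scoring the next range
        int remainingDocs = remainingDocsStart;

        int matched = 25;                                // docs the leaf collector matched in this range
        remainingDocs -= matched;
        int currentPagePos = matched;                    // positions buffered in docsBuilder so far

        limiter.tryAccumulateHits(15);                   // meanwhile the sibling driver claims 15 more

        int collectedDocs = remainingDocsStart - remainingDocs;                        // 25
        int discardedDocs = collectedDocs - limiter.tryAccumulateHits(collectedDocs);  // only 15 accepted, 10 discarded
        currentPagePos -= discardedDocs;                 // page is truncated to the 15 accepted positions

        System.out.println("accepted=" + (collectedDocs - discardedDocs)
            + " discarded=" + discardedDocs
            + " remaining=" + limiter.remaining());      // accepted=15 discarded=10 remaining=0
    }
}

The truncation itself is what buildDocsVector and buildScoresVector do above: they rebuild the doc and score vectors keeping only the first currentPagePos positions.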
LimitOperator.java
@@ -22,15 +22,6 @@
import java.util.Objects;

public class LimitOperator implements Operator {
/**
* Total number of position that are emitted by this operator.
*/
private final int limit;

/**
* Remaining number of positions that will be emitted by this operator.
*/
private int limitRemaining;

/**
* Count of pages that have been processed by this operator.
@@ -49,35 +40,49 @@ public class LimitOperator implements Operator {

private Page lastInput;

private final Limiter limiter;
private boolean finished;

public LimitOperator(int limit) {
this.limit = this.limitRemaining = limit;
public LimitOperator(Limiter limiter) {
this.limiter = limiter;
}

public record Factory(int limit) implements OperatorFactory {
public static final class Factory implements OperatorFactory {
private final Limiter limiter;

public Factory(int limit) {
this.limiter = new Limiter(limit);
}

@Override
public LimitOperator get(DriverContext driverContext) {
return new LimitOperator(limit);
return new LimitOperator(limiter);
}

@Override
public String describe() {
return "LimitOperator[limit = " + limit + "]";
return "LimitOperator[limit = " + limiter.limit() + "]";
}
}

@Override
public boolean needsInput() {
return finished == false && lastInput == null;
return finished == false && lastInput == null && limiter.remaining() > 0;
}

@Override
public void addInput(Page page) {
assert lastInput == null : "has pending input page";
lastInput = page;
rowsReceived += page.getPositionCount();
final int acceptedRows = limiter.tryAccumulateHits(page.getPositionCount());
if (acceptedRows == 0) {
page.releaseBlocks();
assert isFinished();
} else if (acceptedRows < page.getPositionCount()) {
lastInput = truncatePage(page, acceptedRows);
} else {
lastInput = page;
}
rowsReceived += acceptedRows;
}

@Override
@@ -87,55 +92,46 @@ public void finish() {

@Override
public boolean isFinished() {
return finished && lastInput == null;
return lastInput == null && (finished || limiter.remaining() == 0);
}

@Override
public Page getOutput() {
if (lastInput == null) {
return null;
}

Page result;
if (lastInput.getPositionCount() <= limitRemaining) {
result = lastInput;
limitRemaining -= lastInput.getPositionCount();
} else {
int[] filter = new int[limitRemaining];
for (int i = 0; i < limitRemaining; i++) {
filter[i] = i;
}
Block[] blocks = new Block[lastInput.getBlockCount()];
boolean success = false;
try {
for (int b = 0; b < blocks.length; b++) {
blocks[b] = lastInput.getBlock(b).filter(filter);
}
success = true;
} finally {
if (success == false) {
Releasables.closeExpectNoException(lastInput::releaseBlocks, Releasables.wrap(blocks));
} else {
lastInput.releaseBlocks();
}
lastInput = null;
}
result = new Page(blocks);
limitRemaining = 0;
}
if (limitRemaining == 0) {
finished = true;
}
final Page result = lastInput;
lastInput = null;
pagesProcessed++;
rowsEmitted += result.getPositionCount();
return result;
}

private static Page truncatePage(Page page, int upTo) {
int[] filter = new int[upTo];
for (int i = 0; i < upTo; i++) {
filter[i] = i;
}
final Block[] blocks = new Block[page.getBlockCount()];
Page result = null;
try {
for (int b = 0; b < blocks.length; b++) {
blocks[b] = page.getBlock(b).filter(filter);
}
result = new Page(blocks);
} finally {
if (result == null) {
Releasables.closeExpectNoException(page::releaseBlocks, Releasables.wrap(blocks));
} else {
page.releaseBlocks();
}
}
return result;
}

@Override
public Status status() {
return new Status(limit, limitRemaining, pagesProcessed, rowsReceived, rowsEmitted);
return new Status(limiter.limit(), limiter.remaining(), pagesProcessed, rowsReceived, rowsEmitted);
}

@Override
@@ -147,6 +143,8 @@ public void close() {

@Override
public String toString() {
final int limitRemaining = limiter.remaining();
final int limit = limiter.limit();
return "LimitOperator[limit = " + limitRemaining + "/" + limit + "]";
}

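The Factory change above is easiest to read as a shared budget: every LimitOperator the factory returns holds the same Limiter, so drivers running in parallel can never emit more than the configured limit in total. A minimal sketch of the three cases addInput now handles, using only the Limiter API from this change (the limit and page sizes are arbitrary):

import org.elasticsearch.compute.operator.Limiter;

// Illustrative only: the per-page decision addInput makes, expressed against the shared Limiter.
public class LimitBudgetSketch {
    public static void main(String[] args) {
        Limiter shared = new Limiter(10);            // LimitOperator.Factory creates one of these per LIMIT

        int accepted = shared.tryAccumulateHits(7);  // 7: the whole page is kept as lastInput
        System.out.println(accepted);

        accepted = shared.tryAccumulateHits(7);      // 3: the page is truncated to its first 3 rows
        System.out.println(accepted);

        accepted = shared.tryAccumulateHits(7);      // 0: the page is released and the operator is finished
        System.out.println(accepted);
    }
}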
Limiter.java
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import java.util.concurrent.atomic.AtomicInteger;

/**
* A shared limiter used by multiple drivers to collect hits in parallel without exceeding the output limit.
* For example, if the query `FROM test-1,test-2 | LIMIT 100` is run with two drivers, and one driver (e.g., querying `test-1`)
* has collected 60 hits, then the other driver querying `test-2` should collect at most 40 hits.
*/
public class Limiter {
private final int limit;
private final AtomicInteger collected = new AtomicInteger();

public static Limiter NO_LIMIT = new Limiter(Integer.MAX_VALUE) {
@Override
public int tryAccumulateHits(int numHits) {
return numHits;
}

@Override
public int remaining() {
return Integer.MAX_VALUE;
}
};

public Limiter(int limit) {
this.limit = limit;
}

/**
* Returns the remaining number of hits that can be collected.
*/
public int remaining() {
final int remaining = limit - collected.get();
assert remaining >= 0 : remaining;
return remaining;
}

/**
* Returns the limit of this limiter.
*/
public int limit() {
return limit;
}

/**
* Tries to accumulate hits and returns the number of hits that were accepted.
*
* @param numHits the number of hits to try to accumulate
* @return the accepted number of hits. If the returned number is less than numHits,
* it means the limit has been reached and the difference can be discarded.
*/
public int tryAccumulateHits(int numHits) {
while (true) {
int curVal = collected.get();
if (curVal >= limit) {
return 0;
}
final int toAccept = Math.min(limit - curVal, numHits);
if (collected.compareAndSet(curVal, curVal + toAccept)) {
return toAccept;
}
}
}
}
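A minimal, self-contained sketch of the property this class is meant to provide: however many drivers race on tryAccumulateHits, the total number of accepted hits never exceeds the limit. The thread count and batch sizes below are arbitrary:

import org.elasticsearch.compute.operator.Limiter;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative only: many simulated "drivers" racing on one Limiter never accept more than the limit.
public class LimiterRaceSketch {
    public static void main(String[] args) throws InterruptedException {
        final Limiter limiter = new Limiter(1000);
        final AtomicInteger accepted = new AtomicInteger();
        List<Thread> drivers = new ArrayList<>();
        for (int i = 0; i < 8; i++) {
            Thread t = new Thread(() -> {
                // Each simulated driver keeps collecting batches until the shared budget is gone.
                while (limiter.remaining() > 0) {
                    int batch = ThreadLocalRandom.current().nextInt(1, 64);
                    accepted.addAndGet(limiter.tryAccumulateHits(batch));
                }
            });
            drivers.add(t);
            t.start();
        }
        for (Thread t : drivers) {
            t.join();
        }
        // Exactly the limit is accepted, never more, regardless of interleaving.
        System.out.println("accepted = " + accepted.get()); // prints 1000
    }
}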