Merged
Changes from all commits
BlockUtils.java

@@ -174,6 +174,7 @@ public static Block[] fromList(BlockFactory blockFactory, List<List<Object>> list
 
     /** Returns a deep copy of the given block, using the blockFactory for creating the copy block. */
    public static Block deepCopyOf(Block block, BlockFactory blockFactory) {
+        // TODO preserve constants here.
        try (Block.Builder builder = block.elementType().newBlockBuilder(block.getPositionCount(), blockFactory)) {
            builder.copyFrom(block, 0, block.getPositionCount());
            builder.mvOrdering(block.mvOrdering());

Member Author (inline comment on the new TODO): I'm going to do this now.
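As context for that TODO and comment: deepCopyOf rebuilds the copy position by position through a Block.Builder, so a copy of a constant block currently comes back as an ordinary materialized block. A minimal sketch of that round trip, assuming a BlockFactory is in scope (the wrapper class and method name below are illustrative, not from the PR):

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;

class DeepCopySketch {
    // Copy a constant-backed block. The copy holds the same 1024 values but,
    // until the TODO above is addressed, not the compact constant form.
    static Block copyConstant(BlockFactory blockFactory) {
        Block original = blockFactory.newConstantLongBlockWith(42L, 1024);
        try {
            return BlockUtils.deepCopyOf(original, blockFactory);
        } finally {
            original.close();
        }
    }
}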
LuceneCountOperator.java

@@ -12,17 +12,21 @@
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
+import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BooleanBlock;
-import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.SourceOperator;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.Releasables;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Function;
 
 /**

@@ -32,22 +36,16 @@
  * 2. a bool flag (seen) that's always true meaning that the group (all items) always exists
  */
 public class LuceneCountOperator extends LuceneOperator {
-
-    private static final int PAGE_SIZE = 1;
-
-    private int totalHits = 0;
-    private int remainingDocs;
-
-    private final LeafCollector leafCollector;
-
     public static class Factory extends LuceneOperator.Factory {
         private final List<? extends RefCounted> shardRefCounters;
+        private final List<ElementType> tagTypes;
 
         public Factory(
             List<? extends ShardContext> contexts,
             Function<ShardContext, List<LuceneSliceQueue.QueryAndTags>> queryFunction,
             DataPartitioning dataPartitioning,
             int taskConcurrency,
+            List<ElementType> tagTypes,
             int limit
         ) {
             super(
@@ -61,11 +59,12 @@ public Factory(
                 shardContext -> ScoreMode.COMPLETE_NO_SCORES
             );
             this.shardRefCounters = contexts;
+            this.tagTypes = tagTypes;
         }
 
         @Override
         public SourceOperator get(DriverContext driverContext) {
-            return new LuceneCountOperator(shardRefCounters, driverContext.blockFactory(), sliceQueue, limit);
+            return new LuceneCountOperator(shardRefCounters, driverContext.blockFactory(), sliceQueue, tagTypes, limit);
         }
 
         @Override
@@ -74,35 +73,20 @@ public String describe() {
         }
     }
 
+    private final List<ElementType> tagTypes;
+    private final Map<List<Object>, PerTagsState> tagsToState = new HashMap<>();
+    private int remainingDocs;
+
     public LuceneCountOperator(
         List<? extends RefCounted> shardRefCounters,
         BlockFactory blockFactory,
         LuceneSliceQueue sliceQueue,
+        List<ElementType> tagTypes,
         int limit
     ) {
-        super(shardRefCounters, blockFactory, PAGE_SIZE, sliceQueue);
+        super(shardRefCounters, blockFactory, Integer.MAX_VALUE, sliceQueue);
+        this.tagTypes = tagTypes;
         this.remainingDocs = limit;
-        this.leafCollector = new LeafCollector() {
-            @Override
-            public void setScorer(Scorable scorer) {}
-
-            @Override
-            public void collect(DocIdStream stream) throws IOException {
-                if (remainingDocs > 0) {
-                    int count = Math.min(stream.count(), remainingDocs);
-                    totalHits += count;
-                    remainingDocs -= count;
-                }
-            }
-
-            @Override
-            public void collect(int doc) {
-                if (remainingDocs > 0) {
-                    remainingDocs--;
-                    totalHits++;
-                }
-            }
-        };
     }
 
     @Override
@@ -124,59 +108,133 @@ protected Page getCheckedOutput() throws IOException {
         long start = System.nanoTime();
         try {
             final LuceneScorer scorer = getCurrentOrLoadNextScorer();
             // no scorer means no more docs
             if (scorer == null) {
                 remainingDocs = 0;
             } else {
-                if (scorer.tags().isEmpty() == false) {
-                    throw new UnsupportedOperationException("tags not supported by " + getClass());
-                }
-                Weight weight = scorer.weight();
-                var leafReaderContext = scorer.leafReaderContext();
-                // see org.apache.lucene.search.TotalHitCountCollector
-                int leafCount = weight.count(leafReaderContext);
-                if (leafCount != -1) {
-                    // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
-                    // handle doc partitioning where the same leaf can be seen multiple times
-                    // since the count is global, consider it only for the first partition and skip the rest
-                    // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
-                    if (scorer.position() == 0) {
-                        // check to not count over the desired number of docs/limit
-                        var count = Math.min(leafCount, remainingDocs);
-                        totalHits += count;
-                        remainingDocs -= count;
-                    }
-                    scorer.markAsDone();
-                } else {
-                    // could not apply shortcut, trigger the search
-                    // TODO: avoid iterating all documents in multiple calls to make cancellation more responsive.
-                    scorer.scoreNextRange(leafCollector, leafReaderContext.reader().getLiveDocs(), remainingDocs);
-                }
+                count(scorer);
             }
 
-            Page page = null;
-            // emit only one page
-            if (remainingDocs <= 0 && pagesEmitted == 0) {
-                LongBlock count = null;
-                BooleanBlock seen = null;
-                try {
-                    count = blockFactory.newConstantLongBlockWith(totalHits, PAGE_SIZE);
-                    seen = blockFactory.newConstantBooleanBlockWith(true, PAGE_SIZE);
-                    page = new Page(PAGE_SIZE, count, seen);
-                } finally {
-                    if (page == null) {
-                        Releasables.closeExpectNoException(count, seen);
-                    }
-                }
-            }
-            return page;
+            if (remainingDocs <= 0) {
+                return buildResult();
+            }
+            return null;
         } finally {
             processingNanos += System.nanoTime() - start;
         }
     }
 
+    private void count(LuceneScorer scorer) throws IOException {
+        PerTagsState state = tagsToState.computeIfAbsent(scorer.tags(), t -> new PerTagsState());
+        Weight weight = scorer.weight();
+        var leafReaderContext = scorer.leafReaderContext();
+        // see org.apache.lucene.search.TotalHitCountCollector
+        int leafCount = weight.count(leafReaderContext);
+        if (leafCount != -1) {
+            // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
+            // handle doc partitioning where the same leaf can be seen multiple times
+            // since the count is global, consider it only for the first partition and skip the rest
+            // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
+            if (scorer.position() == 0) {
+                // check to not count over the desired number of docs/limit
+                var count = Math.min(leafCount, remainingDocs);
+                state.totalHits += count;
+                remainingDocs -= count;
+            }
+            scorer.markAsDone();
+        } else {
+            // could not apply shortcut, trigger the search
+            // TODO: avoid iterating all documents in multiple calls to make cancellation more responsive.
+            scorer.scoreNextRange(state, leafReaderContext.reader().getLiveDocs(), remainingDocs);
+        }
+    }
+
+    private Page buildResult() {
+        return switch (tagsToState.size()) {
+            case 0 -> null;
+            case 1 -> {
+                Map.Entry<List<Object>, PerTagsState> e = tagsToState.entrySet().iterator().next();
+                yield buildConstantBlocksResult(e.getKey(), e.getValue());
+            }
+            default -> buildNonConstantBlocksResult();
+        };
+    }
+
+    private Page buildConstantBlocksResult(List<Object> tags, PerTagsState state) {
+        Block[] blocks = new Block[2 + tagTypes.size()];
+        int b = 0;
+        try {
+            blocks[b++] = blockFactory.newConstantLongBlockWith(state.totalHits, 1);
+            blocks[b++] = blockFactory.newConstantBooleanBlockWith(true, 1);
+            for (Object e : tags) {
+                blocks[b++] = BlockUtils.constantBlock(blockFactory, e, 1);
+            }
+            Page page = new Page(1, blocks);
+            blocks = null;
+            return page;
+        } finally {
+            if (blocks != null) {
+                Releasables.closeExpectNoException(blocks);
+            }
+        }
+    }
+
+    private Page buildNonConstantBlocksResult() {
+        BlockUtils.BuilderWrapper[] builders = new BlockUtils.BuilderWrapper[tagTypes.size()];
+        Block[] blocks = new Block[2 + tagTypes.size()];
+        try (LongVector.Builder countBuilder = blockFactory.newLongVectorBuilder(tagsToState.size())) {
+            int b = 0;
+            for (ElementType t : tagTypes) {
+                builders[b++] = BlockUtils.wrapperFor(blockFactory, t, tagsToState.size());
+            }
+
+            for (Map.Entry<List<Object>, PerTagsState> e : tagsToState.entrySet()) {
+                countBuilder.appendLong(e.getValue().totalHits);
+                b = 0;
+                for (Object t : e.getKey()) {
+                    builders[b++].accept(t);
+                }
+            }
+
+            blocks[0] = countBuilder.build().asBlock();
+            blocks[1] = blockFactory.newConstantBooleanBlockWith(true, tagsToState.size());
+            for (b = 0; b < builders.length; b++) {
+                blocks[2 + b] = builders[b].builder().build();
+                builders[b] = null;
+            }
+            Page page = new Page(tagsToState.size(), blocks);
+            blocks = null;
+            return page;
+        } finally {
+            Releasables.closeExpectNoException(Releasables.wrap(builders), blocks == null ? () -> {} : Releasables.wrap(blocks));
+        }
+    }
 
     @Override
     protected void describe(StringBuilder sb) {
         sb.append(", remainingDocs=").append(remainingDocs);
     }
+
+    private class PerTagsState implements LeafCollector {
+        long totalHits;
+
+        @Override
+        public void setScorer(Scorable scorer) {}
+
+        @Override
+        public void collect(DocIdStream stream) throws IOException {
+            if (remainingDocs > 0) {
+                int count = Math.min(stream.count(), remainingDocs);
+                totalHits += count;
+                remainingDocs -= count;
+            }
+        }
+
+        @Override
+        public void collect(int doc) {
+            if (remainingDocs > 0) {
+                remainingDocs--;
+                totalHits++;
+            }
+        }
+    }
 }
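To make the new output shape concrete, here is a sketch of consuming the page this operator now emits, per buildResult above: block 0 holds the per-tag hit count, block 1 the constant seen flag, and blocks 2 and up the tag values. The reader class below is illustrative, not part of the PR, and assumes a single LONG tag column:

import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;

class CountPageSketch {
    // One row per distinct tag tuple; every block is single-valued here, so
    // getFirstValueIndex(p) locates the value for position p.
    static void print(Page page) {
        LongBlock counts = page.getBlock(0);  // per-tag hit counts
        BooleanBlock seen = page.getBlock(1); // always true
        LongBlock tag = page.getBlock(2);     // assumed single LONG tag column
        for (int p = 0; p < page.getPositionCount(); p++) {
            System.out.println(
                "tag=" + tag.getLong(tag.getFirstValueIndex(p))
                    + " count=" + counts.getLong(counts.getFirstValueIndex(p))
                    + " seen=" + seen.getBoolean(seen.getFirstValueIndex(p))
            );
        }
    }
}

With no tags in play the map holds a single entry keyed by the empty tuple, so buildResult takes the single-entry branch and emits the same one-row count-plus-seen page the operator produced before this change.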