LuceneQueryEvaluator.java
@@ -17,14 +17,12 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

@@ -44,12 +42,12 @@
* It's much faster to push queries to the {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
* this class is here to save the day.
*/
public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements Releasable {
public abstract class LuceneQueryEvaluator<T extends Block.Builder> implements Releasable {

public record ShardConfig(Query query, IndexSearcher searcher) {}

private final BlockFactory blockFactory;
private final ShardConfig[] shards;
protected final ShardConfig[] shards;

private final List<ShardState> perShardState;

@@ -67,9 +65,9 @@ public Block executeQuery(Page page) {
DocVector docs = (DocVector) block.asVector();
try {
if (docs.singleSegmentNonDecreasing()) {
return evalSingleSegmentNonDecreasing(docs).asBlock();
return evalSingleSegmentNonDecreasing(docs);
} else {
return evalSlow(docs).asBlock();
return evalSlow(docs);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
@@ -106,15 +104,15 @@ public Block executeQuery(Page page) {
* common.
* </p>
*/
private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
private Block evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
ShardState shardState = shardState(docs.shards().getInt(0));
SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
int min = docs.docs().getInt(0);
int max = docs.docs().getInt(docs.getPositionCount() - 1);
int length = max - min + 1;
try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
try (T scoreBuilder = createBlockBuilder(blockFactory, docs.getPositionCount())) {
if (length == docs.getPositionCount() && length > 1) {
return segmentState.scoreDense(scoreBuilder, min, max);
return segmentState.scoreDense(scoreBuilder, min, max, docs.getPositionCount());
}
return segmentState.scoreSparse(scoreBuilder, docs.docs());
}
@@ -134,13 +132,13 @@ private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException
* the order that the {@link DocVector} came in.
* </p>
*/
private Vector evalSlow(DocVector docs) throws IOException {
private Block evalSlow(DocVector docs) throws IOException {
int[] map = docs.shardSegmentDocMapForwards();
// Clear any state flags from the previous run
int prevShard = -1;
int prevSegment = -1;
SegmentState segmentState = null;
try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
try (T scoreBuilder = createBlockBuilder(blockFactory, docs.getPositionCount())) {
for (int i = 0; i < docs.getPositionCount(); i++) {
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
int segment = docs.segments().getInt(map[i]);
@@ -156,7 +154,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
segmentState.scoreSingleDocWithScorer(scoreBuilder, docs.docs().getInt(map[i]));
}
}
try (Vector outOfOrder = scoreBuilder.build()) {
try (Block outOfOrder = scoreBuilder.build()) {
return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
}
}
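Note on the maps used above: evalSlow visits docs in shard/segment/doc order via shardSegmentDocMapForwards(), builds results in that sorted order, and then filter(docs.shardSegmentDocMapBackwards()) restores the arrival order. A standalone sketch of that permutation round trip (hypothetical class and a toy "score", not code from this PR):

/** Toy demonstration of the forward/backward permutation round trip. */
class PermutationRoundTrip {
    public static void main(String[] args) {
        int[] docs = { 5, 1, 9 };     // doc ids in arrival order
        int[] forwards = { 1, 0, 2 }; // visiting docs[1], docs[0], docs[2] is ascending order

        // Score while visiting in ascending doc order (toy score = doc * 10).
        int[] sorted = new int[docs.length];
        for (int i = 0; i < docs.length; i++) {
            sorted[i] = docs[forwards[i]] * 10;
        }

        // Invert the permutation, mirroring shardSegmentDocMapBackwards().
        int[] backwards = new int[forwards.length];
        for (int i = 0; i < forwards.length; i++) {
            backwards[forwards[i]] = i;
        }

        // Mirrors outOfOrder.filter(backwards): prints 5 -> 50, 1 -> 10, 9 -> 90.
        for (int p = 0; p < docs.length; p++) {
            System.out.println(docs[p] + " -> " + sorted[backwards[p]]);
        }
    }
}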
@@ -247,9 +245,9 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
* Score a range using the {@link BulkScorer}. This should be faster
* than using {@link #scoreSparse} for dense doc ids.
*/
Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
Block scoreDense(T scoreBuilder, int min, int max, int positionCount) throws IOException {
if (noMatch) {
return createNoMatchVector(blockFactory, max - min + 1);
return createNoMatchBlock(blockFactory, max - min + 1);
}
if (bulkScorer == null || // The bulkScorer wasn't initialized
Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -258,19 +256,22 @@ Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
bulkScorer = weight.bulkScorer(ctx);
if (bulkScorer == null) {
noMatch = true;
return createNoMatchVector(blockFactory, max - min + 1);
return createNoMatchBlock(blockFactory, positionCount);
}
}
try (
DenseCollector<T> collector = new DenseCollector<>(
min,
max,
scoreBuilder,
ctx,
LuceneQueryEvaluator.this::appendNoMatch,
LuceneQueryEvaluator.this::appendMatch
(builder, scorer1, docId, ctc, query) -> LuceneQueryEvaluator.this.appendMatch(builder, scorer1, docId, ctx, query),
weight.getQuery()
)
) {
bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
collector.finish();
return collector.build();
}
}
@@ -279,10 +280,10 @@ Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
* Score a vector of doc ids using {@link Scorer}. If you have a dense range of
* doc ids it'd be faster to use {@link #scoreDense}.
*/
Vector scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
Block scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
initScorer(docs.getInt(0));
if (noMatch) {
return createNoMatchVector(blockFactory, docs.getPositionCount());
return createNoMatchBlock(blockFactory, docs.getPositionCount());
}
for (int i = 0; i < docs.getPositionCount(); i++) {
scoreSingleDocWithScorer(scoreBuilder, docs.getInt(i));
@@ -308,29 +309,36 @@ private void initScorer(int minDocId) throws IOException {

private void scoreSingleDocWithScorer(T builder, int doc) throws IOException {
if (scorer.iterator().docID() == doc) {
appendMatch(builder, scorer);
appendMatch(builder, scorer, doc, ctx, weight.getQuery());
} else if (scorer.iterator().docID() > doc) {
appendNoMatch(builder);
} else {
if (scorer.iterator().advance(doc) == doc) {
appendMatch(builder, scorer);
appendMatch(builder, scorer, doc, ctx, weight.getQuery());
} else {
appendNoMatch(builder);
}
}
}
}
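The branching in scoreSingleDocWithScorer leans on the DocIdSetIterator contract: advance(target) moves to the first doc id greater than or equal to target and never moves backwards, which is why callers must probe doc ids in ascending order. A minimal sketch of that contract (assumes a Scorer named scorer and a target doc in scope, inside a method that may throw IOException; not code from this PR):

// org.apache.lucene.search.DocIdSetIterator
DocIdSetIterator it = scorer.iterator();
if (it.docID() < doc) {
    it.advance(doc);                 // lands on the first doc id >= doc
}
boolean matches = it.docID() == doc; // landing past doc means doc doesn't match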

@FunctionalInterface
public interface MatchAppender<T, U, E extends Exception> {
void accept(T t, U u, int docId, LeafReaderContext leafReaderContext, Query query) throws E;
}

/**
* Collects matching information for dense range of doc ids. This assumes that
* doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
* which isn't documented, but @jpountz swears is true.
*/
static class DenseCollector<U extends Vector.Builder> implements LeafCollector, Releasable {
static class DenseCollector<U extends Block.Builder> implements LeafCollector, Releasable {
private final U scoreBuilder;
private final int max;
private final LeafReaderContext leafReaderContext;
private final Consumer<U> appendNoMatch;
private final CheckedBiConsumer<U, Scorable, IOException> appendMatch;
private final MatchAppender<U, Scorable, IOException> appendMatch;
private final Query query;

private Scorable scorer;
int next;
@@ -339,14 +347,18 @@ static class DenseCollector<U extends Vector.Builder> implements LeafCollector,
int min,
int max,
U scoreBuilder,
LeafReaderContext leafReaderContext,
Consumer<U> appendNoMatch,
CheckedBiConsumer<U, Scorable, IOException> appendMatch
MatchAppender<U, Scorable, IOException> appendMatch,
Query query
) {
this.scoreBuilder = scoreBuilder;
this.max = max;
next = min;
this.leafReaderContext = leafReaderContext;
this.appendNoMatch = appendNoMatch;
this.appendMatch = appendMatch;
this.query = query;
}

@Override
@@ -359,10 +371,10 @@ public void collect(int doc) throws IOException {
while (next++ < doc) {
appendNoMatch.accept(scoreBuilder);
}
appendMatch.accept(scoreBuilder, scorer);
appendMatch.accept(scoreBuilder, scorer, doc, leafReaderContext, query);
}

public Vector build() {
public Block build() {
return scoreBuilder.build();
}
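For readers skimming the collector: collect(doc) back-fills a no-match for every doc id skipped since the previous call, and finish() (invoked after bulkScorer.score above) pads the tail out to max. A stripped-down sketch of that gap-filling contract, with a List<Boolean> standing in for the block builder (hypothetical class, not part of this PR):

import java.util.ArrayList;
import java.util.List;

/** Toy version of DenseCollector's gap-filling over the range [min, max]. */
class GapFillingCollector {
    private final int max;
    private final List<Boolean> results = new ArrayList<>();
    private int next;

    GapFillingCollector(int min, int max) {
        this.next = min;
        this.max = max;
    }

    void collect(int doc) {     // doc ids must arrive in ascending order
        while (next++ < doc) {
            results.add(false); // back-fill skipped ids as non-matches
        }
        results.add(true);      // the collected id is a match
    }

    void finish() {             // pad the tail so [min, max] is fully covered
        while (next++ <= max) {
            results.add(false);
        }
    }

    public static void main(String[] args) {
        GapFillingCollector c = new GapFillingCollector(0, 4);
        c.collect(1);
        c.collect(3);
        c.finish();
        System.out.println(c.results); // [false, true, false, true, false]
    }
}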

@@ -387,17 +399,18 @@ public void close() {
/**
* Creates a block where all positions correspond to elements that don't match the query
*/
protected abstract Vector createNoMatchVector(BlockFactory blockFactory, int size);
protected abstract Block createNoMatchBlock(BlockFactory blockFactory, int size);

/**
* Creates the corresponding block builder to store the results of evaluating the query
*/
protected abstract T createVectorBuilder(BlockFactory blockFactory, int size);
protected abstract T createBlockBuilder(BlockFactory blockFactory, int size);

/**
* Appends a matching result to a builder created by {@link #createBlockBuilder}
*/
protected abstract void appendMatch(T builder, Scorable scorer) throws IOException;
protected abstract void appendMatch(T builder, Scorable scorer, int docId, LeafReaderContext leafReaderContext, Query query)
throws IOException;

/**
* Appends a non-matching result to a builder created by {@link #createBlockBuilder}
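The point of threading docId, LeafReaderContext, and Query through appendMatch is that subclasses can now react to per-document context rather than just the Scorable. As a hypothetical illustration only (not part of this PR; imports elided, and it assumes the abstract methods shown in this file plus scoreMode() are the whole contract), a subclass that emits each match's global Lucene doc id and -1 elsewhere:

public class LuceneQueryDocIdEvaluator extends LuceneQueryEvaluator<IntBlock.Builder> {
    LuceneQueryDocIdEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
        super(blockFactory, shards);
    }

    @Override
    protected ScoreMode scoreMode() {
        return ScoreMode.COMPLETE_NO_SCORES; // matching only, no scores needed
    }

    @Override
    protected Block createNoMatchBlock(BlockFactory blockFactory, int size) {
        return blockFactory.newConstantIntBlockWith(-1, size);
    }

    @Override
    protected IntBlock.Builder createBlockBuilder(BlockFactory blockFactory, int size) {
        return blockFactory.newIntBlockBuilder(size);
    }

    @Override
    protected void appendNoMatch(IntBlock.Builder builder) {
        builder.appendInt(-1);
    }

    @Override
    protected void appendMatch(IntBlock.Builder builder, Scorable scorer, int docId, LeafReaderContext leafReaderContext, Query query)
        throws IOException {
        builder.appendInt(leafReaderContext.docBase + docId); // segment-relative -> global doc id
    }
}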
LuceneQueryExpressionEvaluator.java
@@ -7,14 +7,15 @@

package org.elasticsearch.compute.lucene;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;

@@ -26,9 +27,7 @@
* a {@link BooleanVector}.
* @see LuceneQueryScoreEvaluator
*/
public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanVector.Builder>
implements
EvalOperator.ExpressionEvaluator {
public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanBlock.Builder> implements EvalOperator.ExpressionEvaluator {

LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
super(blockFactory, shards);
@@ -45,22 +44,23 @@ protected ScoreMode scoreMode() {
}

@Override
protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
return blockFactory.newConstantBooleanVector(false, size);
protected Block createNoMatchBlock(BlockFactory blockFactory, int size) {
return blockFactory.newConstantBooleanBlockWith(false, size);
}

@Override
protected BooleanVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
return blockFactory.newBooleanVectorFixedBuilder(size);
protected BooleanBlock.Builder createBlockBuilder(BlockFactory blockFactory, int size) {
return blockFactory.newBooleanBlockBuilder(size);
}

@Override
protected void appendNoMatch(BooleanVector.Builder builder) {
protected void appendNoMatch(BooleanBlock.Builder builder) {
builder.appendBoolean(false);
}

@Override
protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) throws IOException {
protected void appendMatch(BooleanBlock.Builder builder, Scorable scorer, int docId, LeafReaderContext leafReaderContext, Query query)
throws IOException {
builder.appendBoolean(true);
}

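A rough usage sketch for the boolean evaluator (hedged: the constructor is package-private, so this assumes same-package access, and searcher, blockFactory, and page are assumed to already be in scope):

ShardConfig[] shards = { new ShardConfig(new TermQuery(new Term("g", "a")), searcher) };
try (LuceneQueryExpressionEvaluator eval = new LuceneQueryExpressionEvaluator(blockFactory, shards)) {
    Block mask = eval.executeQuery(page); // BooleanBlock: true at positions the query matches
    // ... consume the mask, then release it
    mask.close();
}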
LuceneQueryScoreEvaluator.java
@@ -7,14 +7,14 @@

package org.elasticsearch.compute.lucene;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;

@@ -27,7 +27,7 @@
* Elements that don't match will have a score of {@link #NO_MATCH_SCORE}.
* @see LuceneQueryExpressionEvaluator
*/
public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleVector.Builder> implements ScoreOperator.ExpressionScorer {
public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleBlock.Builder> implements ScoreOperator.ExpressionScorer {

public static final double NO_MATCH_SCORE = 0.0;

@@ -46,22 +46,23 @@ protected ScoreMode scoreMode() {
}

@Override
protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, size);
protected DoubleBlock createNoMatchBlock(BlockFactory blockFactory, int size) {
return blockFactory.newConstantDoubleBlockWith(NO_MATCH_SCORE, size);
}

@Override
protected DoubleVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
return blockFactory.newDoubleVectorFixedBuilder(size);
protected DoubleBlock.Builder createBlockBuilder(BlockFactory blockFactory, int size) {
return blockFactory.newDoubleBlockBuilder(size);
}

@Override
protected void appendNoMatch(DoubleVector.Builder builder) {
protected void appendNoMatch(DoubleBlock.Builder builder) {
builder.appendDouble(NO_MATCH_SCORE);
}

@Override
protected void appendMatch(DoubleVector.Builder builder, Scorable scorer) throws IOException {
protected void appendMatch(DoubleBlock.Builder builder, Scorable scorer, int docId, LeafReaderContext leafReaderContext, Query query)
throws IOException {
builder.appendDouble(scorer.score());
}

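Same wiring as the boolean evaluator above, but the output is a DoubleBlock: scorer.score() at matching positions and NO_MATCH_SCORE everywhere else. A sketch, under the same same-package and in-scope assumptions:

try (LuceneQueryScoreEvaluator eval = new LuceneQueryScoreEvaluator(blockFactory, shards)) {
    DoubleBlock scores = (DoubleBlock) eval.executeQuery(page);
    for (int p = 0; p < scores.getPositionCount(); p++) {
        double s = scores.getDouble(scores.getFirstValueIndex(p)); // 0.0 for non-matches
    }
    scores.close();
}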
LuceneQueryEvaluatorTests.java
@@ -25,14 +25,13 @@
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.OperatorTests;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
@@ -59,7 +58,7 @@
/**
* Base class for testing Lucene query evaluators.
*/
public abstract class LuceneQueryEvaluatorTests<T extends Vector, U extends Vector.Builder> extends ComputeTestCase {
public abstract class LuceneQueryEvaluatorTests<T extends Block, U extends Block.Builder> extends ComputeTestCase {

private static final String FIELD = "g";

@@ -168,9 +167,9 @@ protected void assertTermsQuery(List<Page> results, Set<String> matching, int ex
int matchCount = 0;
for (Page page : results) {
int initialBlockIndex = termsBlockIndex(page);
BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
BytesRefBlock terms = page.<BytesRefBlock>getBlock(initialBlockIndex);
@SuppressWarnings("unchecked")
T resultVector = (T) page.getBlock(resultsBlockIndex(page)).asVector();
T resultVector = (T) page.getBlock(resultsBlockIndex(page));
for (int i = 0; i < page.getPositionCount(); i++) {
BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
boolean isMatch = matching.contains(termAtPosition.utf8ToString());