
Commit 978770d

LuceneQueryScoreEvaluator first implementation
1 parent e5ea00a commit 978770d

File tree

2 files changed: +659 -0 lines changed
Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.lucene;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

import java.io.IOException;
import java.io.UncheckedIOException;
/**
 * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
 * the compute engine's normal execution, yielding the score of each matching
 * document into a {@link DoubleVector} rather than just match/no-match flags in
 * a {@link BooleanVector}. It's much faster to push scoring to the
 * {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
 * this evaluator is here to save the day.
 */
public class LuceneQueryScoreEvaluator implements EvalOperator.ExpressionEvaluator {

    public static final double NO_MATCH_SCORE = 0.0;

    public record ShardConfig(Query query, IndexSearcher searcher) {}

    private final BlockFactory blockFactory;
    private final ShardConfig[] shards;

    private ShardState[] perShardState = EMPTY_SHARD_STATES;

    public LuceneQueryScoreEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
        this.blockFactory = blockFactory;
        this.shards = shards;
    }

    @Override
    public Block eval(Page page) {
        // Lucene based operators emit the DocVector as the first block of the page
        Block block = page.getBlock(0);
        assert block instanceof DocBlock : "LuceneQueryScoreEvaluator expects DocBlock as input";
        DocVector docs = (DocVector) block.asVector();
        try {
            if (docs.singleSegmentNonDecreasing()) {
                return evalSingleSegmentNonDecreasing(docs).asBlock();
            } else {
                return evalSlow(docs).asBlock();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
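
    // A DocVector carries three parallel IntVectors: shards(), segments() and docs(). Row i refers
    // to document docs().getInt(i) in segment segments().getInt(i) of shard shards().getInt(i).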

    /**
     * Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents.
     * <p>
     * ESQL receives documents in a DocVector, and they can be in one of two
     * states. Either the DocVector contains documents from a single segment
     * in non-decreasing order, or it doesn't. That first case is much more like
     * how Lucene likes to process documents, and it's much more common. So we
     * optimize for it.
     * <p>
     * Vectors that are {@link DocVector#singleSegmentNonDecreasing()}
     * represent many documents from a single Lucene segment. In Elasticsearch
     * terms that's a segment in a single shard. And the document ids are in
     * non-decreasing order. Probably just {@code 0, 1, 2, 3, 4, 5...}.
     * But maybe something like {@code 0, 5, 6, 10, 10, 10}. Both of those are
     * very much like how Lucene "natively" processes documents and this optimizes
     * those accesses.
     * </p>
     * <p>
     * If the documents are literally {@code 0, 1, ... n} then we use
     * {@link BulkScorer#score(LeafCollector, Bits, int, int)} which scores
     * a whole range. This should be quite common in the case where we don't
     * have deleted documents because that's the order that
     * {@link LuceneSourceOperator} produces them.
     * </p>
     * <p>
     * If there are gaps in the sequence we use {@link Scorer} calls to
     * score the sequence. This'll be slower, but it isn't going to be
     * particularly common.
     * </p>
     */
    private DoubleVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
        ShardState shardState = shardState(docs.shards().getInt(0));
        SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
        int min = docs.docs().getInt(0);
        int max = docs.docs().getInt(docs.getPositionCount() - 1);
        int length = max - min + 1;
        if (length == docs.getPositionCount() && length > 1) {
            return segmentState.scoreDense(min, max);
        }
        return segmentState.scoreSparse(docs.docs());
    }
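    // Example: doc ids {3, 4, 5, 6} give min=3, max=6, so length == positionCount and the whole run
    // is scored with scoreDense(3, 6); doc ids {3, 5, 6, 9} give length=7 != 4 positions, so each id
    // is scored individually via scoreSparse.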

    /**
     * Evaluate non-{@link DocVector#singleSegmentNonDecreasing()} documents. See
     * {@link #evalSingleSegmentNonDecreasing} for the meaning of
     * {@link DocVector#singleSegmentNonDecreasing()} and how we can efficiently
     * evaluate those segments.
     * <p>
     * This processes the worst case blocks of documents. These can be from any
     * number of shards and any number of segments and in any order. We do this
     * by iterating the docs in {@code shard ASC, segment ASC, doc ASC} order.
     * So, that's segment by segment, docs ascending. We build a vector of scores
     * out of that. Then we <strong>sort</strong> that to put the scores back in
     * the order that the {@link DocVector} came in.
     * </p>
     */
    private DoubleVector evalSlow(DocVector docs) throws IOException {
        int[] map = docs.shardSegmentDocMapForwards();
        // Track the previous shard and segment so we only switch segment state when they change
        int prevShard = -1;
        int prevSegment = -1;
        SegmentState segmentState = null;
        try (DoubleVector.Builder builder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) {
            for (int i = 0; i < docs.getPositionCount(); i++) {
                int shard = docs.shards().getInt(map[i]);
                int segment = docs.segments().getInt(map[i]);
                if (segmentState == null || prevShard != shard || prevSegment != segment) {
                    segmentState = shardState(shard).segmentState(segment);
                    segmentState.initScorer(docs.docs().getInt(map[i]));
                    prevShard = shard;
                    prevSegment = segment;
                }
                if (segmentState.noMatch) {
                    builder.appendDouble(NO_MATCH_SCORE);
                } else {
                    segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i]));
                }
            }
            try (DoubleVector outOfOrder = builder.build()) {
                return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
            }
        }
    }
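    // Example: rows with (shard, segment, doc) = (0,1,7), (0,0,3), (0,0,9) are visited in sorted
    // order (0,0,3), (0,0,9), (0,1,7) via shardSegmentDocMapForwards(), scored in that order, and
    // then filter(shardSegmentDocMapBackwards()) puts the three scores back into the original row order.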

    @Override
    public void close() {

    }

    private ShardState shardState(int shard) throws IOException {
        if (shard >= perShardState.length) {
            perShardState = ArrayUtil.grow(perShardState, shard + 1);
        } else if (perShardState[shard] != null) {
            return perShardState[shard];
        }
        perShardState[shard] = new ShardState(shards[shard]);
        return perShardState[shard];
    }

    private class ShardState {
        private final Weight weight;
        private final IndexSearcher searcher;
        private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;

        ShardState(ShardConfig config) throws IOException {
            weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE, 1.0f);
            searcher = config.searcher;
        }

        SegmentState segmentState(int segment) throws IOException {
            if (segment >= perSegmentState.length) {
                perSegmentState = ArrayUtil.grow(perSegmentState, segment + 1);
            } else if (perSegmentState[segment] != null) {
                return perSegmentState[segment];
            }
            perSegmentState[segment] = new SegmentState(weight, searcher.getLeafContexts().get(segment));
            return perSegmentState[segment];
        }
    }

    private class SegmentState {
        private final Weight weight;
        private final LeafReaderContext ctx;

        /**
         * Lazily initialized {@link Scorer} for this segment. {@code null} here means uninitialized
         * <strong>or</strong> that {@link #noMatch} is true.
         */
        private Scorer scorer;

        /**
         * Thread that initialized the {@link #scorer}.
         */
        private Thread scorerThread;

        /**
         * Lazily initialized {@link BulkScorer} for this segment. {@code null} here means uninitialized
         * <strong>or</strong> that {@link #noMatch} is true.
         */
        private BulkScorer bulkScorer;

        /**
         * Thread that initialized the {@link #bulkScorer}.
         */
        private Thread bulkScorerThread;

        /**
         * Set to {@code true} if, in the process of building a {@link Scorer} or {@link BulkScorer},
         * the {@link Weight} tells us there aren't any matches.
         */
        private boolean noMatch;

        private SegmentState(Weight weight, LeafReaderContext ctx) {
            this.weight = weight;
            this.ctx = ctx;
        }

        /**
         * Score a range using the {@link BulkScorer}. This should be faster
         * than using {@link #scoreSparse} for dense doc ids.
         */
        DoubleVector scoreDense(int min, int max) throws IOException {
            int length = max - min + 1;
            if (noMatch) {
                return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length);
            }
            if (bulkScorer == null || // The bulkScorer wasn't initialized
                Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
            ) {
                bulkScorerThread = Thread.currentThread();
                bulkScorer = weight.bulkScorer(ctx);
                if (bulkScorer == null) {
                    noMatch = true;
                    return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, length);
                }
            }
            try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) {
                bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
                return collector.build();
            }
        }

        /**
         * Score a vector of doc ids using {@link Scorer}. If you have a dense range of
         * doc ids it'd be faster to use {@link #scoreDense}.
         */
        DoubleVector scoreSparse(IntVector docs) throws IOException {
            initScorer(docs.getInt(0));
            if (noMatch) {
                return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, docs.getPositionCount());
            }
            try (DoubleVector.Builder builder = blockFactory.newDoubleVectorFixedBuilder(docs.getPositionCount())) {
                for (int i = 0; i < docs.getPositionCount(); i++) {
                    scoreSingleDocWithScorer(builder, docs.getInt(i));
                }
                return builder.build();
            }
        }

        private void initScorer(int minDocId) throws IOException {
            if (noMatch) {
                return;
            }
            if (scorer == null || // Scorer not initialized
                scorerThread != Thread.currentThread() || // Scorer initialized on a different thread
                scorer.iterator().docID() > minDocId // The previous block came "after" this one
            ) {
                scorerThread = Thread.currentThread();
                scorer = weight.scorer(ctx);
                if (scorer == null) {
                    noMatch = true;
                }
            }
        }

        private void scoreSingleDocWithScorer(DoubleVector.Builder builder, int doc) throws IOException {
            if (scorer.iterator().docID() == doc) {
                builder.appendDouble(scorer.score());
            } else if (scorer.iterator().docID() > doc) {
                builder.appendDouble(NO_MATCH_SCORE);
            } else {
                builder.appendDouble(scorer.iterator().advance(doc) == doc ? scorer.score() : NO_MATCH_SCORE);
            }
        }
    }

    private static final ShardState[] EMPTY_SHARD_STATES = new ShardState[0];
    private static final SegmentState[] EMPTY_SEGMENT_STATES = new SegmentState[0];

    /**
     * Collects the score for each doc in a dense range of doc ids. This assumes that
     * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order,
     * which isn't documented, but @jpountz swears is true.
     */
    static class DenseCollector implements LeafCollector, Releasable {
        private final DoubleVector.FixedBuilder builder;
        private final int max;
        private Scorable scorer;

        int next;

        DenseCollector(BlockFactory blockFactory, int min, int max) {
            this.builder = blockFactory.newDoubleVectorFixedBuilder(max - min + 1);
            this.max = max;
            next = min;
        }

        @Override
        public void setScorer(Scorable scorable) {
            this.scorer = scorable;
        }

        @Override
        public void collect(int doc) throws IOException {
            while (next++ < doc) {
                builder.appendDouble(NO_MATCH_SCORE);
            }
            builder.appendDouble(scorer.score());
        }

        public DoubleVector build() {
            return builder.build();
        }

        @Override
        public void finish() {
            while (next++ <= max) {
                builder.appendDouble(NO_MATCH_SCORE);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(builder);
        }
    }
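    // Example: DenseCollector(blockFactory, 5, 9) with hits at docs 6 and 8 builds
    // [NO_MATCH_SCORE, score(6), NO_MATCH_SCORE, score(8), NO_MATCH_SCORE]; collect() back-fills the
    // gap before each hit and finish() pads the tail out to max.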

    public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
        private final ShardConfig[] shardConfigs;

        public Factory(ShardConfig[] shardConfigs) {
            this.shardConfigs = shardConfigs;
        }

        @Override
        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
            return new LuceneQueryScoreEvaluator(context.blockFactory(), shardConfigs);
        }
    }
}
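
A minimal usage sketch of the new evaluator (not part of the diff above). The names query, searcher, driverContext and page are hypothetical stand-ins for an existing Lucene Query, an IndexSearcher over the target shard, the driver's DriverContext, and a Page whose first block is a DocBlock produced by a Lucene source operator:

    // One ShardConfig per shard, indexed by the shard id that appears in the DocVector.
    LuceneQueryScoreEvaluator.ShardConfig[] shards = new LuceneQueryScoreEvaluator.ShardConfig[] {
        new LuceneQueryScoreEvaluator.ShardConfig(query, searcher)
    };
    EvalOperator.ExpressionEvaluator.Factory factory = new LuceneQueryScoreEvaluator.Factory(shards);
    try (EvalOperator.ExpressionEvaluator scorer = factory.get(driverContext)) {
        // Returns a block of doubles: the Lucene score for rows that match, NO_MATCH_SCORE otherwise.
        Block scores = scorer.eval(page);
    }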

0 commit comments
