Commit 63ca98b

Refactor query evaluators to use subclasses instead of interfaces
1 parent 145955c commit 63ca98b

File tree

6 files changed: +457 −719 lines changed
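
The shape of the refactor, in brief: the ScoreVectorBuilder callback interface and the builder-supplier constructor argument are removed, and LuceneQueryEvaluator becomes generic over the concrete Vector.Builder type, exposing protected abstract hooks (scoreMode, createNoMatchVector, createBuilder, appendMatch, appendNoMatch) that subclasses implement directly. A rough, simplified sketch of that template-method shape, using hypothetical names rather than the real Elasticsearch classes:

// Hypothetical, simplified sketch of the pattern change; the real classes use
// Elasticsearch's BlockFactory / Vector.Builder types and carry much more logic.

// Before: per-row behavior was injected through an interface plus a supplier argument.
interface ScoreBuilderLike {
    void appendMatch();
    void appendNoMatch();
}

// After: the base class is generic over the builder type, and the shared driver
// logic calls abstract hooks that each subclass overrides (template method style).
abstract class QueryEvaluatorSketch<T> {
    protected abstract T createBuilder(int size);
    protected abstract void appendMatch(T builder);
    protected abstract void appendNoMatch(T builder);

    // Shared driver logic: walks the matches and delegates each append to a hook.
    final T evaluate(boolean[] matches) {
        T builder = createBuilder(matches.length);
        for (boolean match : matches) {
            if (match) {
                appendMatch(builder);
            } else {
                appendNoMatch(builder);
            }
        }
        return builder;
    }
}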

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java

Lines changed: 62 additions & 60 deletions
@@ -16,8 +16,8 @@
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.Weight;
-import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.Bits;
+import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BooleanVector;
@@ -32,7 +32,10 @@
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
-import java.util.function.BiFunction;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Consumer;
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
@@ -41,26 +44,22 @@
  * {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
  * this evaluator is here to save the day.
  */
-public abstract class LuceneQueryEvaluator implements Releasable {
-
-    public static final double NO_MATCH_SCORE = 0.0;
+public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements Releasable {
 
     public record ShardConfig(Query query, IndexSearcher searcher) {}
 
     private final BlockFactory blockFactory;
     private final ShardConfig[] shards;
-    private final BiFunction<BlockFactory, Integer, ScoreVectorBuilder> scoreVectorBuilderSupplier;
 
-    private ShardState[] perShardState = EMPTY_SHARD_STATES;
+    private final List<ShardState> perShardState;
 
     protected LuceneQueryEvaluator(
         BlockFactory blockFactory,
-        ShardConfig[] shards,
-        BiFunction<BlockFactory, Integer, ScoreVectorBuilder> scoreVectorBuilderSupplier
+        ShardConfig[] shards
     ) {
         this.blockFactory = blockFactory;
         this.shards = shards;
-        this.scoreVectorBuilderSupplier = scoreVectorBuilderSupplier;
+        this.perShardState = new ArrayList<>(Collections.nCopies(shards.length, null));
     }
 
     public Block executeQuery(Page page) {
@@ -115,7 +114,7 @@ private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException
         int min = docs.docs().getInt(0);
         int max = docs.docs().getInt(docs.getPositionCount() - 1);
         int length = max - min + 1;
-        try (ScoreVectorBuilder scoreBuilder = scoreVectorBuilderSupplier.apply(blockFactory, length)) {
+        try (T scoreBuilder = createBuilder(blockFactory, length)) {
             if (length == docs.getPositionCount() && length > 1) {
                 return segmentState.scoreDense(scoreBuilder, min, max);
             }
@@ -143,8 +142,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
         int prevShard = -1;
         int prevSegment = -1;
         SegmentState segmentState = null;
-        try (ScoreVectorBuilder scoreBuilder = scoreVectorBuilderSupplier.apply(blockFactory, docs.getPositionCount())) {
-            scoreBuilder.initVector();
+        try (T scoreBuilder = createBuilder(blockFactory, docs.getPositionCount())) {
             for (int i = 0; i < docs.getPositionCount(); i++) {
                 int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
                 int segment = docs.segments().getInt(map[i]);
@@ -155,7 +153,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
                     prevSegment = segment;
                 }
                 if (segmentState.noMatch) {
-                    scoreBuilder.appendNoMatch();
+                    appendNoMatch(scoreBuilder);
                 } else {
                     segmentState.scoreSingleDocWithScorer(scoreBuilder, docs.docs().getInt(map[i]));
                 }
@@ -170,40 +168,39 @@
     public void close() {
     }
 
-    protected abstract ScoreMode scoreMode();
-
     private ShardState shardState(int shard) throws IOException {
-        if (shard >= perShardState.length) {
-            perShardState = ArrayUtil.grow(perShardState, shard + 1);
-        } else if (perShardState[shard] != null) {
-            return perShardState[shard];
+        ShardState shardState = perShardState.get(shard);
+        if (shardState != null) {
+            return shardState;
         }
-        perShardState[shard] = new ShardState(shards[shard]);
-        return perShardState[shard];
+        shardState = new ShardState(shards[shard]);
+        perShardState.set(shard, shardState);
+        return shardState;
     }
 
     private class ShardState {
         private final Weight weight;
         private final IndexSearcher searcher;
-        private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;
+        private final List<SegmentState> perSegmentState;
 
         ShardState(ShardConfig config) throws IOException {
             weight = config.searcher.createWeight(config.query, scoreMode(), 1.0f);
             searcher = config.searcher;
+            perSegmentState = new ArrayList<>(Collections.nCopies(searcher.getLeafContexts().size(), null));
         }
 
         SegmentState segmentState(int segment) throws IOException {
-            if (segment >= perSegmentState.length) {
-                perSegmentState = ArrayUtil.grow(perSegmentState, segment + 1);
-            } else if (perSegmentState[segment] != null) {
-                return perSegmentState[segment];
+            SegmentState segmentState = perSegmentState.get(segment);
+            if (segmentState != null) {
+                return segmentState;
             }
-            perSegmentState[segment] = new SegmentState(weight, searcher.getLeafContexts().get(segment));
-            return perSegmentState[segment];
+            segmentState = new SegmentState(weight, searcher.getLeafContexts().get(segment));
+            perSegmentState.set(segment, segmentState);
+            return segmentState;
         }
     }
 
-    private static class SegmentState {
+    private class SegmentState {
         private final Weight weight;
         private final LeafReaderContext ctx;
 
@@ -244,9 +241,9 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
          * Score a range using the {@link BulkScorer}. This should be faster
          * than using {@link #scoreSparse} for dense doc ids.
          */
-        Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOException {
+        Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
             if (noMatch) {
-                return scoreBuilder.createNoMatchVector();
+                return createNoMatchVector(blockFactory, max - min + 1);
             }
             if (bulkScorer == null || // The bulkScorer wasn't initialized
                 Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -255,10 +252,12 @@ Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOEx
                 bulkScorer = weight.bulkScorer(ctx);
                 if (bulkScorer == null) {
                     noMatch = true;
-                    return scoreBuilder.createNoMatchVector();
+                    return createNoMatchVector(blockFactory, max - min + 1);
                 }
             }
-            try (DenseCollector collector = new DenseCollector(min, max, scoreBuilder)) {
+            try (DenseCollector<T> collector = new DenseCollector<>(min, max, scoreBuilder,
+                LuceneQueryEvaluator.this::appendNoMatch,
+                LuceneQueryEvaluator.this::appendMatch)) {
                 bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
                 return collector.build();
             }
@@ -268,12 +267,11 @@ Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOEx
          * Score a vector of doc ids using {@link Scorer}. If you have a dense range of
          * doc ids it'd be faster to use {@link #scoreDense}.
          */
-        Vector scoreSparse(ScoreVectorBuilder scoreBuilder, IntVector docs) throws IOException {
+        Vector scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
             initScorer(docs.getInt(0));
             if (noMatch) {
-                return scoreBuilder.createNoMatchVector();
+                return createNoMatchVector(blockFactory, docs.getPositionCount());
             }
-            scoreBuilder.initVector();
             for (int i = 0; i < docs.getPositionCount(); i++) {
                 scoreSingleDocWithScorer(scoreBuilder, docs.getInt(i));
             }
@@ -296,41 +294,47 @@ private void initScorer(int minDocId) throws IOException {
             }
         }
 
-        private void scoreSingleDocWithScorer(ScoreVectorBuilder builder, int doc) throws IOException {
+        private void scoreSingleDocWithScorer(T builder, int doc) throws IOException {
             if (scorer.iterator().docID() == doc) {
-                builder.appendMatch(scorer);
+                appendMatch(builder, scorer);
             } else if (scorer.iterator().docID() > doc) {
-                builder.appendNoMatch();
+                appendNoMatch(builder);
            } else {
                if (scorer.iterator().advance(doc) == doc) {
-                    builder.appendMatch(scorer);
+                    appendMatch(builder, scorer);
                } else {
-                    builder.appendNoMatch();
+                    appendNoMatch(builder);
                }
            }
        }
    }
 
-    private static final ShardState[] EMPTY_SHARD_STATES = new ShardState[0];
-    private static final SegmentState[] EMPTY_SEGMENT_STATES = new SegmentState[0];
-
    /**
     * Collects matching information for dense range of doc ids. This assumes that
     * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
     * which isn't documented, but @jpountz swears is true.
     */
-    static class DenseCollector implements LeafCollector, Releasable {
-        private final ScoreVectorBuilder scoreBuilder;
+    static class DenseCollector<U extends Vector.Builder> implements LeafCollector, Releasable {
+        private final U scoreBuilder;
         private final int max;
-        private Scorable scorer;
+        private final Consumer<U> appendNoMatch;
+        private final CheckedBiConsumer<U, Scorable, IOException> appendMatch;
 
+        private Scorable scorer;
         int next;
 
-        DenseCollector(int min, int max, ScoreVectorBuilder scoreBuilder) {
+        DenseCollector(
+            int min,
+            int max,
+            U scoreBuilder,
+            Consumer<U> appendNoMatch,
+            CheckedBiConsumer<U, Scorable, IOException> appendMatch
+        ) {
            this.scoreBuilder = scoreBuilder;
-            scoreBuilder.initVector();
            this.max = max;
            next = min;
+            this.appendNoMatch = appendNoMatch;
+            this.appendMatch = appendMatch;
        }
 
        @Override
@@ -341,9 +345,9 @@ public void setScorer(Scorable scorable) {
        @Override
        public void collect(int doc) throws IOException {
            while (next++ < doc) {
-                scoreBuilder.appendNoMatch();
+                appendNoMatch.accept(scoreBuilder);
            }
-            scoreBuilder.appendMatch(scorer);
+            appendMatch.accept(scoreBuilder, scorer);
        }
 
        public Vector build() {
@@ -353,7 +357,7 @@ public Vector build() {
        @Override
        public void finish() {
            while (next++ <= max) {
-                scoreBuilder.appendNoMatch();
+                appendNoMatch.accept(scoreBuilder);
            }
        }
 
@@ -363,15 +367,13 @@ public void close() {
        }
    }
 
-    public interface ScoreVectorBuilder extends Releasable {
-        Vector createNoMatchVector();
+    protected abstract ScoreMode scoreMode();
 
-        void initVector();
+    protected abstract Vector createNoMatchVector(BlockFactory blockFactory, int size);
 
-        void appendNoMatch();
+    protected abstract T createBuilder(BlockFactory blockFactory, int size);
 
-        void appendMatch(Scorable scorer) throws IOException;
+    protected abstract void appendNoMatch(T builder);
 
-        Vector build();
-    }
+    protected abstract void appendMatch(T builder, Scorable scorer) throws IOException;
 }
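
A side note on the new lazy per-shard and per-segment caches: new ArrayList<>(Collections.nCopies(n, null)) builds a mutable, fixed-length list pre-filled with nulls, so state can be filled in with set(...) on first use instead of growing arrays with ArrayUtil.grow. A standalone illustration of the idiom; the cache class and its loadShard helper below are hypothetical stand-ins for the real ShardState construction:

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class LazySlotCache {
    private final List<String> slots;

    LazySlotCache(int size) {
        // Pre-sized, null-filled, and mutable; copying into ArrayList is what
        // allows set(index, value) later, unlike the nCopies view on its own.
        this.slots = new ArrayList<>(Collections.nCopies(size, null));
    }

    String get(int i) {
        String cached = slots.get(i);
        if (cached != null) {
            return cached;
        }
        String created = loadShard(i); // hypothetical expensive construction
        slots.set(i, created);
        return created;
    }

    private String loadShard(int i) {
        return "shard-" + i;
    }
}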

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

Lines changed: 24 additions & 56 deletions
@@ -17,7 +17,8 @@
 import org.elasticsearch.compute.data.Vector;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator;
-import org.elasticsearch.core.Releasables;
+
+import java.io.IOException;
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
@@ -26,15 +27,15 @@
  * {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
  * this evaluator is here to save the day.
  */
-public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator implements EvalOperator.ExpressionEvaluator {
-
-    public static final double NO_MATCH_SCORE = 0.0;
+public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanVector.Builder>
+    implements
+        EvalOperator.ExpressionEvaluator {
 
     LuceneQueryExpressionEvaluator(
         BlockFactory blockFactory,
         ShardConfig[] shards
     ) {
-        super(blockFactory, shards, BooleanScoreVectorBuilder::new);
+        super(blockFactory, shards);
     }
 
     @Override
@@ -47,63 +48,30 @@ protected ScoreMode scoreMode() {
         return ScoreMode.COMPLETE_NO_SCORES;
     }
 
-    public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
-        private final ShardConfig[] shardConfigs;
-
-        public Factory(ShardConfig[] shardConfigs) {
-            this.shardConfigs = shardConfigs;
-        }
-
-        @Override
-        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
-            return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
-        }
+    @Override
+    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
+        return blockFactory.newConstantBooleanVector(false, size);
     }
 
-    static class BooleanScoreVectorBuilder implements ScoreVectorBuilder {
-
-        private final BlockFactory blockFactory;
-        private final int size;
-
-        private BooleanVector.Builder builder;
-
-        BooleanScoreVectorBuilder(BlockFactory blockFactory, int size) {
-            this.blockFactory = blockFactory;
-            this.size = size;
-        }
-
-        @Override
-        public Vector createNoMatchVector() {
-            return blockFactory.newConstantBooleanVector(false, size);
-        }
-
-        @Override
-        public void initVector() {
-            assert builder == null : "initVector called twice";
-            builder = blockFactory.newBooleanVectorFixedBuilder(size);
-        }
-
-        @Override
-        public void appendNoMatch() {
-            assert builder != null : "appendNoMatch called before initVector";
-            builder.appendBoolean(false);
-        }
+    @Override
+    protected BooleanVector.Builder createBuilder(BlockFactory blockFactory, int size) {
+        return blockFactory.newBooleanVectorFixedBuilder(size);
+    }
 
-        @Override
-        public void appendMatch(Scorable scorer) {
-            assert builder != null : "appendMatch called before initVector";
-            builder.appendBoolean(true);
-        }
+    @Override
+    protected void appendNoMatch(BooleanVector.Builder builder) {
+        builder.appendBoolean(false);
+    }
 
-        @Override
-        public Vector build() {
-            assert builder != null : "build called before initVector";
-            return builder.build();
-        }
+    @Override
+    protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) throws IOException {
+        builder.appendBoolean(true);
+    }
 
+    public record Factory(ShardConfig[] shardConfigs) implements EvalOperator.ExpressionEvaluator.Factory {
         @Override
-        public void close() {
-            Releasables.closeExpectNoException(builder);
+        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
+            return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
         }
     }
 }
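
For callers, the Factory is now a record but keeps the same get(DriverContext) entry point. A hedged sketch of how it might be wired up, based only on the signatures visible in this diff; the query, searcher, and DriverContext are assumed to be supplied by the surrounding compute engine code:

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.ShardConfig;
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;

class QueryEvaluatorWiring {
    // Hypothetical helper: builds one ShardConfig and obtains the evaluator
    // through the record-based Factory, bound to the context's BlockFactory.
    static EvalOperator.ExpressionEvaluator build(Query query, IndexSearcher searcher, DriverContext driverContext) {
        ShardConfig[] shards = new ShardConfig[] { new ShardConfig(query, searcher) };
        return new LuceneQueryExpressionEvaluator.Factory(shards).get(driverContext);
    }
}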
