Skip to content

Commit f76f0a7

Browse files
committed
Refactor LuceneQueryExpressionEvaluator to make scoring behaviour and vector returned pluggable
1 parent 1c34a05 commit f76f0a7

File tree

8 files changed

+149
-69
lines changed

8 files changed

+149
-69
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

Lines changed: 107 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.compute.data.DocVector;
2626
import org.elasticsearch.compute.data.IntVector;
2727
import org.elasticsearch.compute.data.Page;
28+
import org.elasticsearch.compute.data.Vector;
2829
import org.elasticsearch.compute.operator.DriverContext;
2930
import org.elasticsearch.compute.operator.EvalOperator;
3031
import org.elasticsearch.core.Releasable;
@@ -45,12 +46,14 @@ public record ShardConfig(Query query, IndexSearcher searcher) {}
4546

4647
private final BlockFactory blockFactory;
4748
private final ShardConfig[] shards;
49+
private final DocScorerVectorProvider docScorerVectorProvider;
4850

4951
private ShardState[] perShardState = EMPTY_SHARD_STATES;
5052

51-
public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
53+
LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards, DocScorerVectorProvider docScorerVectorProvider) {
5254
this.blockFactory = blockFactory;
5355
this.shards = shards;
56+
this.docScorerVectorProvider = docScorerVectorProvider;
5457
}
5558

5659
@Override
@@ -65,6 +68,8 @@ public Block eval(Page page) {
6568
}
6669
} catch (IOException e) {
6770
throw new UncheckedIOException(e);
71+
} finally {
72+
Releasables.closeExpectNoException(docScorerVectorProvider);
6873
}
6974
}
7075

@@ -98,7 +103,7 @@ public Block eval(Page page) {
98103
* common.
99104
* </p>
100105
*/
101-
private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
106+
private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
102107
ShardState shardState = shardState(docs.shards().getInt(0));
103108
SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
104109
int min = docs.docs().getInt(0);
@@ -124,32 +129,31 @@ private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOEx
124129
* the order that the {@link DocVector} came in.
125130
* </p>
126131
*/
127-
private BooleanVector evalSlow(DocVector docs) throws IOException {
132+
private Vector evalSlow(DocVector docs) throws IOException {
128133
int[] map = docs.shardSegmentDocMapForwards();
129134
// Clear any state flags from the previous run
130135
int prevShard = -1;
131136
int prevSegment = -1;
132137
SegmentState segmentState = null;
133-
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
134-
for (int i = 0; i < docs.getPositionCount(); i++) {
135-
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
136-
int segment = docs.segments().getInt(map[i]);
137-
if (segmentState == null || prevShard != shard || prevSegment != segment) {
138-
segmentState = shardState(shard).segmentState(segment);
139-
segmentState.initScorer(docs.docs().getInt(map[i]));
140-
prevShard = shard;
141-
prevSegment = segment;
142-
}
143-
if (segmentState.noMatch) {
144-
builder.appendBoolean(false);
145-
} else {
146-
segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i]));
147-
}
138+
docScorerVectorProvider.init(docs.getPositionCount());
139+
for (int i = 0; i < docs.getPositionCount(); i++) {
140+
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
141+
int segment = docs.segments().getInt(map[i]);
142+
if (segmentState == null || prevShard != shard || prevSegment != segment) {
143+
segmentState = shardState(shard).segmentState(segment);
144+
segmentState.initScorer(docs.docs().getInt(map[i]));
145+
prevShard = shard;
146+
prevSegment = segment;
148147
}
149-
try (BooleanVector outOfOrder = builder.build()) {
150-
return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
148+
if (segmentState.noMatch) {
149+
docScorerVectorProvider.scoreNoHit();
150+
} else {
151+
segmentState.scoreSingleDocWithScorer(docs.docs().getInt(map[i]));
151152
}
152153
}
154+
try (Vector outOfOrder = docScorerVectorProvider.build()) {
155+
return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
156+
}
153157
}
154158

155159
@Override
@@ -244,7 +248,7 @@ BooleanVector scoreDense(int min, int max) throws IOException {
244248
return blockFactory.newConstantBooleanVector(false, length);
245249
}
246250
}
247-
try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) {
251+
try (DenseCollector collector = new DenseCollector(blockFactory, docScorerVectorProvider, min, max)) {
248252
bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
249253
return collector.build();
250254
}
@@ -254,17 +258,16 @@ BooleanVector scoreDense(int min, int max) throws IOException {
254258
* Score a vector of doc ids using {@link Scorer}. If you have a dense range of
255259
* doc ids it'd be faster to use {@link #scoreDense}.
256260
*/
257-
BooleanVector scoreSparse(IntVector docs) throws IOException {
261+
Vector scoreSparse(IntVector docs) throws IOException {
258262
initScorer(docs.getInt(0));
259263
if (noMatch) {
260-
return blockFactory.newConstantBooleanVector(false, docs.getPositionCount());
264+
return docScorerVectorProvider.noneMatch(docs.getPositionCount());
261265
}
262-
try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
263-
for (int i = 0; i < docs.getPositionCount(); i++) {
264-
scoreSingleDocWithScorer(builder, docs.getInt(i));
265-
}
266-
return builder.build();
266+
docScorerVectorProvider.init(docs.getPositionCount());
267+
for (int i = 0; i < docs.getPositionCount(); i++) {
268+
scoreSingleDocWithScorer(docs.getInt(i));
267269
}
270+
return docScorerVectorProvider.build();
268271
}
269272

270273
private void initScorer(int minDocId) throws IOException {
@@ -283,13 +286,17 @@ private void initScorer(int minDocId) throws IOException {
283286
}
284287
}
285288

286-
private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) throws IOException {
289+
private void scoreSingleDocWithScorer(int doc) throws IOException {
287290
if (scorer.iterator().docID() == doc) {
288-
builder.appendBoolean(true);
291+
docScorerVectorProvider.scoreHit(scorer);
289292
} else if (scorer.iterator().docID() > doc) {
290-
builder.appendBoolean(false);
293+
docScorerVectorProvider.scoreNoHit();
291294
} else {
292-
builder.appendBoolean(scorer.iterator().advance(doc) == doc);
295+
if (scorer.iterator().advance(doc) == doc) {
296+
docScorerVectorProvider.scoreHit(scorer);
297+
} else {
298+
docScorerVectorProvider.scoreNoHit();
299+
}
293300
}
294301
}
295302
}
@@ -305,17 +312,22 @@ private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) th
305312
static class DenseCollector implements LeafCollector, Releasable {
306313
private final BooleanVector.FixedBuilder builder;
307314
private final int max;
315+
private final DocScorerVectorProvider docScorerVectorProvider;
316+
private Scorable scorable;
308317

309318
int next;
310319

311-
DenseCollector(BlockFactory blockFactory, int min, int max) {
320+
DenseCollector(BlockFactory blockFactory, DocScorerVectorProvider docScorerVectorProvider, int min, int max) {
312321
this.builder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
313322
this.max = max;
323+
this.docScorerVectorProvider = docScorerVectorProvider;
314324
next = min;
315325
}
316326

317327
@Override
318-
public void setScorer(Scorable scorable) {}
328+
public void setScorer(Scorable scorable) {
329+
this.scorable = scorable;
330+
}
319331

320332
@Override
321333
public void collect(int doc) {
@@ -342,6 +354,62 @@ public void close() {
342354
}
343355
}
344356

357+
private interface DocScorerVectorProvider extends Releasable {
358+
359+
Vector noneMatch(int docs);
360+
361+
void init(int numDocs);
362+
363+
void scoreHit(Scorable scorable);
364+
365+
void scoreNoHit();
366+
367+
Vector build();
368+
}
369+
370+
static class NonScoringDocScorerVectorProvider implements DocScorerVectorProvider {
371+
372+
private final BlockFactory blockFactory;
373+
private BooleanVector.Builder builder;
374+
375+
NonScoringDocScorerVectorProvider(BlockFactory blockFactory) {
376+
this.blockFactory = blockFactory;
377+
}
378+
379+
@Override
380+
public Vector noneMatch(int docs) {
381+
return blockFactory.newConstantBooleanVector(false, docs);
382+
}
383+
384+
@Override
385+
public void init(int numDocs) {
386+
builder = blockFactory.newBooleanVectorFixedBuilder(numDocs);
387+
}
388+
389+
@Override
390+
public void scoreHit(Scorable scorable) {
391+
assert builder != null : "init must be called before scoring";
392+
builder.appendBoolean(true);
393+
}
394+
395+
@Override
396+
public void scoreNoHit() {
397+
assert builder != null : "init must be called before scoring";
398+
builder.appendBoolean(false);
399+
}
400+
401+
@Override
402+
public Vector build() {
403+
assert builder != null : "init must be called before scoring";
404+
return builder.build();
405+
}
406+
407+
@Override
408+
public void close() {
409+
Releasables.closeExpectNoException(builder);
410+
}
411+
}
412+
345413
public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
346414
private final ShardConfig[] shardConfigs;
347415

@@ -351,7 +419,11 @@ public Factory(ShardConfig[] shardConfigs) {
351419

352420
@Override
353421
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
354-
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
422+
return new LuceneQueryExpressionEvaluator(
423+
context.blockFactory(),
424+
shardConfigs,
425+
new NonScoringDocScorerVectorProvider(context.blockFactory())
426+
);
355427
}
356428
}
357429
}

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,14 @@ public class LuceneQueryExpressionEvaluatorTests extends ComputeTestCase {
5858
private static final String FIELD = "g";
5959

6060
public void testDenseCollectorSmall() {
61-
try (DenseCollector collector = new DenseCollector(blockFactory(), 0, 2)) {
61+
try (
62+
DenseCollector collector = new DenseCollector(
63+
blockFactory(),
64+
new LuceneQueryExpressionEvaluator.NonScoringDocScorerVectorProvider(blockFactory()),
65+
0,
66+
2
67+
)
68+
) {
6269
collector.collect(0);
6370
collector.collect(1);
6471
collector.collect(2);
@@ -72,7 +79,14 @@ public void testDenseCollectorSmall() {
7279
}
7380

7481
public void testDenseCollectorSimple() {
75-
try (DenseCollector collector = new DenseCollector(blockFactory(), 0, 10)) {
82+
try (
83+
DenseCollector collector = new DenseCollector(
84+
blockFactory(),
85+
new LuceneQueryExpressionEvaluator.NonScoringDocScorerVectorProvider(blockFactory()),
86+
0,
87+
10
88+
)
89+
) {
7690
collector.collect(2);
7791
collector.collect(5);
7892
collector.finish();
@@ -89,7 +103,14 @@ public void testDenseCollector() {
89103
int min = between(0, Integer.MAX_VALUE - length - 1);
90104
int max = min + length + 1;
91105
boolean[] expected = new boolean[length];
92-
try (DenseCollector collector = new DenseCollector(blockFactory(), min, max)) {
106+
try (
107+
DenseCollector collector = new DenseCollector(
108+
blockFactory(),
109+
new LuceneQueryExpressionEvaluator.NonScoringDocScorerVectorProvider(blockFactory()),
110+
min,
111+
max
112+
)
113+
) {
93114
for (int i = 0; i < length; i++) {
94115
expected[i] = randomBoolean();
95116
if (expected[i]) {
@@ -183,8 +204,8 @@ private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs
183204
);
184205
LuceneQueryExpressionEvaluator luceneQueryEvaluator = new LuceneQueryExpressionEvaluator(
185206
blockFactory,
186-
new LuceneQueryExpressionEvaluator.ShardConfig[] { shard }
187-
207+
new LuceneQueryExpressionEvaluator.ShardConfig[] { shard },
208+
new LuceneQueryExpressionEvaluator.NonScoringDocScorerVectorProvider(blockFactory)
188209
);
189210

190211
List<Operator> operators = new ArrayList<>();

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/MatchFunctionIT.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.action.index.IndexRequest;
1212
import org.elasticsearch.action.support.WriteRequest;
1313
import org.elasticsearch.common.settings.Settings;
14-
import org.elasticsearch.xpack.esql.EsqlTestUtils;
1514
import org.elasticsearch.xpack.esql.VerificationException;
1615
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
1716
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
@@ -285,12 +284,12 @@ public void testDisjunctionScoring() {
285284
assertThat(values.get(2).get(0), equalTo(6));
286285

287286
// Matches full text query and non pushable query
288-
assertThat((Double)values.get(0).get(1), greaterThan(2.0));
287+
assertThat((Double) values.get(0).get(1), greaterThan(2.0));
289288
// Matches just non pushable query
290-
assertThat((Double)values.get(1).get(1), equalTo(1));
289+
assertThat((Double) values.get(1).get(1), equalTo(1.0));
291290
// Matches just full text query
292-
assertThat((Double)values.get(2).get(1), lessThan(1.0));
293-
assertThat((Double)values.get(2).get(1), greaterThan(0.0));
291+
assertThat((Double) values.get(2).get(1), lessThan(1.0));
292+
assertThat((Double) values.get(2).get(1), greaterThan(0.0));
294293
}
295294
}
296295

@@ -312,11 +311,11 @@ public void testDisjunctionScoringMultipleNonPushableFunctions() {
312311
assertThat(values.get(1).get(0), equalTo(6));
313312

314313
// Matches the full text query and a non pushable query
315-
assertThat((Double)values.get(0).get(1), greaterThan(1.0));
316-
assertThat((Double)values.get(0).get(1), lessThan(2.0));
314+
assertThat((Double) values.get(0).get(1), greaterThan(1.0));
315+
assertThat((Double) values.get(0).get(1), lessThan(2.0));
317316
// Matches just the match function
318-
assertThat((Double)values.get(1).get(1), lessThan(1.0));
319-
assertThat((Double)values.get(1).get(1), greaterThan(0.0));
317+
assertThat((Double) values.get(1).get(1), lessThan(1.0));
318+
assertThat((Double) values.get(1).get(1), greaterThan(0.0));
320319
}
321320
}
322321

@@ -340,11 +339,11 @@ public void testScoresAreSimilarToQstr() {
340339
var matchValues = getValuesList(respMatch);
341340
var qstrValues = getValuesList(respQstr);
342341
assertEquals(matchValues.size(), qstrValues.size());
343-
for(int i = 0 ; i < matchValues.size(); i++) {
342+
for (int i = 0; i < matchValues.size(); i++) {
344343
// Compare ids
345344
assertEquals(matchValues.get(i).get(0), qstrValues.get(i).get(0));
346345
// Compare scores
347-
assertThat((Double)matchValues.get(i).get(1), closeTo((Double) qstrValues.get(i).get(1), 0.01));
346+
assertThat((Double) matchValues.get(i).get(1), closeTo((Double) qstrValues.get(i).get(1), 0.01));
348347
}
349348
}
350349
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.elasticsearch.xpack.esql.evaluator.mapper.BooleanToScoringExpressionEvaluator;
3131
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
3232
import org.elasticsearch.xpack.esql.evaluator.mapper.ExpressionMapper;
33-
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
3433
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
3534
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
3635
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryScoringLogic;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3838
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
3939
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
40-
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
4140
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
4241
import org.elasticsearch.xpack.esql.querydsl.query.TranslationAwareExpressionQuery;
4342

@@ -228,9 +227,9 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu
228227
);
229228
checkFullTextFunctionsParents(condition, failures);
230229

231-
// if (PlannerUtils.usesScoring(plan)) {
232-
// checkFullTextSearchDisjunctions(condition, failures);
233-
// }
230+
// if (PlannerUtils.usesScoring(plan)) {
231+
// checkFullTextSearchDisjunctions(condition, failures);
232+
// }
234233
} else {
235234
plan.forEachExpression(FullTextFunction.class, ftf -> {
236235
failures.add(fail(ftf, "[{}] {} is only supported in WHERE commands", ftf.functionName(), ftf.functionType()));

0 commit comments

Comments
 (0)