Skip to content

Commit 1b66a2b

Browse files
authored
ES|QL: Handle multi values in FUSE (#135448)
1 parent d4cd655 commit 1b66a2b

File tree

15 files changed

+386
-130
lines changed

15 files changed

+386
-130
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperator.java

Lines changed: 118 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.compute.data.Block;
16+
import org.elasticsearch.compute.data.BlockUtils;
1617
import org.elasticsearch.compute.data.BytesRefBlock;
17-
import org.elasticsearch.compute.data.DoubleVector;
18-
import org.elasticsearch.compute.data.DoubleVectorBlock;
18+
import org.elasticsearch.compute.data.DoubleBlock;
1919
import org.elasticsearch.compute.data.Page;
2020
import org.elasticsearch.compute.operator.DriverContext;
2121
import org.elasticsearch.compute.operator.Operator;
22+
import org.elasticsearch.compute.operator.Warnings;
2223
import org.elasticsearch.core.Releasables;
2324
import org.elasticsearch.core.TimeValue;
2425
import org.elasticsearch.xcontent.XContentBuilder;
@@ -28,6 +29,7 @@
2829
import java.util.Collection;
2930
import java.util.Deque;
3031
import java.util.HashMap;
32+
import java.util.List;
3133
import java.util.Map;
3234

3335
/**
@@ -40,11 +42,26 @@
4042
*
4143
*/
4244
public class LinearScoreEvalOperator implements Operator {
43-
public record Factory(int discriminatorPosition, int scorePosition, LinearConfig linearConfig) implements OperatorFactory {
45+
public record Factory(
46+
int discriminatorPosition,
47+
int scorePosition,
48+
LinearConfig linearConfig,
49+
String sourceText,
50+
int sourceLine,
51+
int sourceColumn
52+
) implements OperatorFactory {
4453

4554
@Override
4655
public Operator get(DriverContext driverContext) {
47-
return new LinearScoreEvalOperator(discriminatorPosition, scorePosition, linearConfig);
56+
return new LinearScoreEvalOperator(
57+
driverContext,
58+
discriminatorPosition,
59+
scorePosition,
60+
linearConfig,
61+
sourceText,
62+
sourceLine,
63+
sourceColumn
64+
);
4865
}
4966

5067
@Override
@@ -74,11 +91,30 @@ public String describe() {
7491
private long rowsReceived = 0;
7592
private long rowsEmitted = 0;
7693

77-
public LinearScoreEvalOperator(int discriminatorPosition, int scorePosition, LinearConfig config) {
94+
private final String sourceText;
95+
private final int sourceLine;
96+
private final int sourceColumn;
97+
private Warnings warnings;
98+
private final DriverContext driverContext;
99+
100+
public LinearScoreEvalOperator(
101+
DriverContext driverContext,
102+
int discriminatorPosition,
103+
int scorePosition,
104+
LinearConfig config,
105+
String sourceText,
106+
int sourceLine,
107+
int sourceColumn
108+
) {
78109
this.scorePosition = scorePosition;
79110
this.discriminatorPosition = discriminatorPosition;
80111
this.config = config;
81112
this.normalizer = createNormalizer(config.normalizer());
113+
this.driverContext = driverContext;
114+
115+
this.sourceText = sourceText;
116+
this.sourceLine = sourceLine;
117+
this.sourceColumn = sourceColumn;
82118

83119
finished = false;
84120
inputPages = new ArrayDeque<>();
@@ -123,25 +159,54 @@ private void createOutputPages() {
123159

124160
private void processInputPage(Page inputPage) {
125161
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
126-
DoubleVectorBlock initialScoreBlock = inputPage.getBlock(scorePosition);
162+
DoubleBlock initialScoreBlock = inputPage.getBlock(scorePosition);
127163

128164
Page newPage = null;
129165
Block scoreBlock = null;
130-
DoubleVector.Builder scores = null;
166+
DoubleBlock.Builder scores = null;
131167

132168
try {
133-
scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());
169+
scores = discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount());
134170

135171
for (int i = 0; i < inputPage.getPositionCount(); i++) {
136-
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
172+
Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i);
173+
174+
if (discriminatorValue == null) {
175+
warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores"));
176+
scores.appendNull();
177+
continue;
178+
} else if (discriminatorValue instanceof List<?>) {
179+
warnings().registerException(
180+
new IllegalArgumentException("group column contains multivalued entries; assigning null scores")
181+
);
182+
scores.appendNull();
183+
continue;
184+
}
185+
String discriminator = ((BytesRef) discriminatorValue).utf8ToString();
137186

138187
var weight = config.weights().get(discriminator) == null ? 1.0 : config.weights().get(discriminator);
139188

140-
double score = initialScoreBlock.getDouble(i);
189+
initialScoreBlock.doesHaveMultivaluedFields();
190+
191+
Object scoreValue = BlockUtils.toJavaObject(initialScoreBlock, i);
192+
if (scoreValue == null) {
193+
warnings().registerException(new IllegalArgumentException("score column has null values; assigning null scores"));
194+
scores.appendNull();
195+
continue;
196+
} else if (scoreValue instanceof List<?>) {
197+
warnings().registerException(
198+
new IllegalArgumentException("score column contains multivalued entries; assigning null scores")
199+
);
200+
scores.appendNull();
201+
continue;
202+
}
203+
204+
double score = (double) scoreValue;
205+
141206
scores.appendDouble(weight * normalizer.normalize(score, discriminator));
142207
}
143208

144-
scoreBlock = scores.build().asBlock();
209+
scoreBlock = scores.build();
145210
newPage = inputPage.appendBlock(scoreBlock);
146211

147212
int[] projections = new int[newPage.getBlockCount() - 1];
@@ -270,23 +335,43 @@ private Normalizer createNormalizer(LinearConfig.Normalizer normalizer) {
270335
};
271336
}
272337

273-
private interface Normalizer {
274-
double normalize(double score, String discriminator);
338+
private abstract static class Normalizer {
339+
abstract double normalize(double score, String discriminator);
275340

276-
void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition);
341+
abstract void preprocess(double score, String discriminator);
342+
343+
void finalizePreprocess() {};
344+
345+
void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
346+
for (Page inputPage : inputPages) {
347+
DoubleBlock scoreBlock = inputPage.getBlock(scorePosition);
348+
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
349+
350+
for (int i = 0; i < inputPage.getPositionCount(); i++) {
351+
Object scoreValue = BlockUtils.toJavaObject(scoreBlock, i);
352+
Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i);
353+
354+
if (scoreValue instanceof Double score && discriminatorValue instanceof BytesRef discriminator) {
355+
preprocess(score, discriminator.utf8ToString());
356+
}
357+
}
358+
}
359+
360+
finalizePreprocess();
361+
}
277362
}
278363

279-
private class NoneNormalizer implements Normalizer {
364+
private static class NoneNormalizer extends Normalizer {
280365
@Override
281366
public double normalize(double score, String discriminator) {
282367
return score;
283368
}
284369

285370
@Override
286-
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {}
371+
void preprocess(double score, String discriminator) {}
287372
}
288373

289-
private class L2NormNormalizer implements Normalizer {
374+
private static class L2NormNormalizer extends Normalizer {
290375
private final Map<String, Double> l2Norms = new HashMap<>();
291376

292377
@Override
@@ -297,24 +382,17 @@ public double normalize(double score, String discriminator) {
297382
}
298383

299384
@Override
300-
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
301-
for (Page inputPage : inputPages) {
302-
DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
303-
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
304-
305-
for (int i = 0; i < inputPage.getPositionCount(); i++) {
306-
double score = scoreBlock.getDouble(i);
307-
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
308-
309-
l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score);
310-
}
311-
}
385+
void preprocess(double score, String discriminator) {
386+
l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score);
387+
}
312388

389+
@Override
390+
void finalizePreprocess() {
313391
l2Norms.replaceAll((k, v) -> Math.sqrt(v));
314392
}
315393
}
316394

317-
private class MinMaxNormalizer implements Normalizer {
395+
private static class MinMaxNormalizer extends Normalizer {
318396
private final Map<String, Double> minScores = new HashMap<>();
319397
private final Map<String, Double> maxScores = new HashMap<>();
320398

@@ -334,19 +412,17 @@ public double normalize(double score, String discriminator) {
334412
}
335413

336414
@Override
337-
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
338-
for (Page inputPage : inputPages) {
339-
DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
340-
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
341-
342-
for (int i = 0; i < inputPage.getPositionCount(); i++) {
343-
double score = scoreBlock.getDouble(i);
344-
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
415+
void preprocess(double score, String discriminator) {
416+
minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score));
417+
maxScores.compute(discriminator, (key, value) -> value == null ? score : Math.max(value, score));
418+
}
419+
}
345420

346-
minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score));
347-
maxScores.compute(discriminator, (key, value) -> value == null ? score : Math.max(value, score));
348-
}
349-
}
421+
private Warnings warnings() {
422+
if (warnings == null) {
423+
this.warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText);
350424
}
425+
426+
return warnings;
351427
}
352428
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperator.java

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@
99

1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.compute.data.Block;
12+
import org.elasticsearch.compute.data.BlockUtils;
1213
import org.elasticsearch.compute.data.BytesRefBlock;
13-
import org.elasticsearch.compute.data.DoubleVector;
14+
import org.elasticsearch.compute.data.DoubleBlock;
1415
import org.elasticsearch.compute.data.Page;
1516
import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
1617
import org.elasticsearch.compute.operator.DriverContext;
1718
import org.elasticsearch.compute.operator.Operator;
19+
import org.elasticsearch.compute.operator.Warnings;
1820
import org.elasticsearch.core.Releasables;
1921

2022
import java.util.HashMap;
23+
import java.util.List;
2124

2225
/**
2326
* Updates the score column with new scores using the RRF formula.
@@ -27,10 +30,25 @@
2730
*/
2831
public class RrfScoreEvalOperator extends AbstractPageMappingOperator {
2932

30-
public record Factory(int discriminatorPosition, int scorePosition, RrfConfig rrfConfig) implements OperatorFactory {
33+
public record Factory(
34+
int discriminatorPosition,
35+
int scorePosition,
36+
RrfConfig rrfConfig,
37+
String sourceText,
38+
int sourceLine,
39+
int sourceColumn
40+
) implements OperatorFactory {
3141
@Override
3242
public Operator get(DriverContext driverContext) {
33-
return new RrfScoreEvalOperator(discriminatorPosition, scorePosition, rrfConfig);
43+
return new RrfScoreEvalOperator(
44+
driverContext,
45+
discriminatorPosition,
46+
scorePosition,
47+
rrfConfig,
48+
sourceText,
49+
sourceLine,
50+
sourceColumn
51+
);
3452
}
3553

3654
@Override
@@ -48,37 +66,62 @@ public String describe() {
4866
private final int scorePosition;
4967
private final int discriminatorPosition;
5068
private final RrfConfig config;
69+
private Warnings warnings;
70+
private final DriverContext driverContext;
71+
private final String sourceText;
72+
private final int sourceLine;
73+
private final int sourceColumn;
5174

5275
private HashMap<String, Integer> counters = new HashMap<>();
5376

54-
public RrfScoreEvalOperator(int discriminatorPosition, int scorePosition, RrfConfig config) {
77+
public RrfScoreEvalOperator(
78+
DriverContext driverContext,
79+
int discriminatorPosition,
80+
int scorePosition,
81+
RrfConfig config,
82+
String sourceText,
83+
int sourceLine,
84+
int sourceColumn
85+
) {
5586
this.scorePosition = scorePosition;
5687
this.discriminatorPosition = discriminatorPosition;
5788
this.config = config;
89+
this.driverContext = driverContext;
90+
this.sourceText = sourceText;
91+
this.sourceLine = sourceLine;
92+
this.sourceColumn = sourceColumn;
5893
}
5994

6095
@Override
6196
protected Page process(Page page) {
62-
BytesRefBlock discriminatorBlock = (BytesRefBlock) page.getBlock(discriminatorPosition);
63-
64-
DoubleVector.Builder scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());
97+
BytesRefBlock discriminatorBlock = page.getBlock(discriminatorPosition);
98+
DoubleBlock.Builder scores = discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount());
6599

66100
for (int i = 0; i < page.getPositionCount(); i++) {
67-
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
68-
69-
int rank = counters.getOrDefault(discriminator, 1);
70-
counters.put(discriminator, rank + 1);
71-
72-
var weight = config.weights().getOrDefault(discriminator, 1.0);
73-
74-
scores.appendDouble(1.0 / (config.rankConstant() + rank) * weight);
101+
Object value = BlockUtils.toJavaObject(discriminatorBlock, i);
102+
103+
if (value == null) {
104+
warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores"));
105+
scores.appendNull();
106+
} else if (value instanceof List<?>) {
107+
warnings().registerException(
108+
new IllegalArgumentException("group column contains multivalued entries; assigning null scores")
109+
);
110+
scores.appendNull();
111+
} else {
112+
String discriminator = ((BytesRef) value).utf8ToString();
113+
int rank = counters.getOrDefault(discriminator, 1);
114+
var weight = config.weights().getOrDefault(discriminator, 1.0);
115+
scores.appendDouble(1.0 / (config.rankConstant() + rank) * weight);
116+
counters.put(discriminator, rank + 1);
117+
}
75118
}
76119

77120
Page newPage = null;
78121
Block scoreBlock = null;
79122

80123
try {
81-
scoreBlock = scores.build().asBlock();
124+
scoreBlock = scores.build();
82125
newPage = page.appendBlock(scoreBlock);
83126

84127
int[] projections = new int[newPage.getBlockCount() - 1];
@@ -105,4 +148,12 @@ protected Page process(Page page) {
105148
public String toString() {
106149
return "RrfScoreEvalOperator";
107150
}
151+
152+
private Warnings warnings() {
153+
if (warnings == null) {
154+
this.warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText);
155+
}
156+
157+
return warnings;
158+
}
108159
}

0 commit comments

Comments
 (0)