Skip to content

Commit ef688f7

Browse files
committed
Testing that _score is an existing column
1 parent b1d9f4a commit ef688f7

File tree

8 files changed

+144
-44
lines changed

8 files changed

+144
-44
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
2828
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
2929
import org.elasticsearch.xpack.esql.core.expression.Literal;
30+
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
3031
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
3132
import org.elasticsearch.xpack.esql.core.expression.Nullability;
3233
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@@ -736,16 +737,36 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
736737
return new Fork(fork.source(), fork.child(), newSubPlans);
737738
}
738739

739-
private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childOutput) {
740+
private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput) {
740741
List<Alias> newFields = new ArrayList<>();
741742
boolean changed = false;
743+
744+
// First resolving fields used in expression
742745
for (Alias field : rerank.rerankFields()) {
743-
Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childOutput));
746+
Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childrenOutput));
744747
newFields.add(result);
745748
changed |= result != field;
746749
}
747750

748-
return changed ? new Rerank(rerank.source(), rerank.child(), rerank.inferenceId(), rerank.queryText(), newFields) : rerank;
751+
if (changed) {
752+
rerank = rerank.withRerankFields(newFields);
753+
}
754+
755+
// Ensure the score attribute is resolved
756+
if (rerank.scoreAttribute() instanceof UnresolvedAttribute ua) {
757+
Attribute resolved = resolveAttribute(ua, childrenOutput);
758+
if (resolved.resolved() == false) {
759+
resolved = ua.withUnresolvedMessage(format(null, "Missing required column [{}] for RERANK", MetadataAttribute.SCORE));
760+
} else if (resolved.dataType() != DOUBLE) {
761+
resolved = ua.withUnresolvedMessage(
762+
format("_score has the wrong type; [{}] expected but got [{}]", DataType.DOUBLE, resolved.dataType())
763+
);
764+
}
765+
766+
rerank = rerank.withScoreAttribute(resolved);
767+
}
768+
769+
return rerank;
749770
}
750771

751772
private List<Attribute> resolveUsingColumns(List<Attribute> cols, List<Attribute> output, String side) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,29 @@
3838
public class Rerank extends InferencePlan implements SortAgnostic, SurrogateLogicalPlan {
3939

4040
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
41+
private final Attribute scoreAttribute;
4142
private final Expression queryText;
4243
private final List<Alias> rerankFields;
4344

4445
public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
4546
super(source, child, inferenceId);
4647
this.queryText = queryText;
4748
this.rerankFields = rerankFields;
49+
this.scoreAttribute = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
50+
}
51+
52+
public Rerank(
53+
Source source,
54+
LogicalPlan child,
55+
Expression inferenceId,
56+
Expression queryText,
57+
List<Alias> rerankFields,
58+
Attribute scoreAttribute
59+
) {
60+
super(source, child, inferenceId);
61+
this.queryText = queryText;
62+
this.rerankFields = rerankFields;
63+
this.scoreAttribute = scoreAttribute;
4864
}
4965

5066
public Rerank(StreamInput in) throws IOException {
@@ -53,7 +69,8 @@ public Rerank(StreamInput in) throws IOException {
5369
in.readNamedWriteable(LogicalPlan.class),
5470
in.readNamedWriteable(Expression.class),
5571
in.readNamedWriteable(Expression.class),
56-
in.readCollectionAsList(Alias::new)
72+
in.readCollectionAsList(Alias::new),
73+
in.readNamedWriteable(Attribute.class)
5774
);
5875
}
5976

@@ -62,6 +79,7 @@ public void writeTo(StreamOutput out) throws IOException {
6279
super.writeTo(out);
6380
out.writeNamedWriteable(queryText);
6481
out.writeCollection(rerankFields());
82+
out.writeNamedWriteable(scoreAttribute);
6583
}
6684

6785
public Expression queryText() {
@@ -72,15 +90,27 @@ public List<Alias> rerankFields() {
7290
return rerankFields;
7391
}
7492

93+
public Attribute scoreAttribute() {
94+
return scoreAttribute;
95+
}
96+
7597
@Override
7698
public TaskType taskType() {
7799
return TaskType.RERANK;
78100
}
79101

80102
@Override
81-
public LogicalPlan withInferenceResolutionError(String inferenceId, String error) {
103+
public Rerank withInferenceResolutionError(String inferenceId, String error) {
82104
Expression newInferenceId = new UnresolvedAttribute(inferenceId().source(), inferenceId, error);
83-
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields);
105+
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
106+
}
107+
108+
public Rerank withRerankFields(List<Alias> newRerankFields) {
109+
return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute);
110+
}
111+
112+
public Rerank withScoreAttribute(Attribute newScoreAttribute) {
113+
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute);
84114
}
85115

86116
@Override
@@ -90,12 +120,14 @@ public String getWriteableName() {
90120

91121
@Override
92122
public UnaryPlan replaceChild(LogicalPlan newChild) {
93-
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields);
123+
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
94124
}
95125

96126
@Override
97127
protected AttributeSet computeReferences() {
98-
return computeReferences(rerankFields);
128+
AttributeSet refs = computeReferences(rerankFields);
129+
refs.add(scoreAttribute);
130+
return refs;
99131
}
100132

101133
public static AttributeSet computeReferences(List<Alias> fields) {
@@ -105,12 +137,12 @@ public static AttributeSet computeReferences(List<Alias> fields) {
105137

106138
@Override
107139
public boolean expressionsResolved() {
108-
return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields);
140+
return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields) && scoreAttribute.resolved();
109141
}
110142

111143
@Override
112144
protected NodeInfo<? extends LogicalPlan> info() {
113-
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields);
145+
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
114146
}
115147

116148
@Override
@@ -119,18 +151,18 @@ public boolean equals(Object o) {
119151
if (o == null || getClass() != o.getClass()) return false;
120152
if (super.equals(o) == false) return false;
121153
Rerank rerank = (Rerank) o;
122-
return Objects.equals(queryText, rerank.queryText) && Objects.equals(rerankFields, rerank.rerankFields);
154+
return Objects.equals(queryText, rerank.queryText)
155+
&& Objects.equals(rerankFields, rerank.rerankFields)
156+
&& Objects.equals(scoreAttribute, rerank.scoreAttribute);
123157
}
124158

125159
@Override
126160
public int hashCode() {
127-
return Objects.hash(super.hashCode(), queryText, rerankFields);
161+
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
128162
}
129163

130164
@Override
131165
public LogicalPlan surrogate() {
132-
Attribute scoreAttribute = child().output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get();
133-
assert scoreAttribute != null;
134166
Order sortOrder = new Order(source(), scoreAttribute, Order.OrderDirection.DESC, Order.NullsPosition.ANY);
135167
return new OrderBy(source(), this, List.of(sortOrder));
136168
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.xpack.esql.core.expression.Alias;
14+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1415
import org.elasticsearch.xpack.esql.core.expression.Expression;
1516
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1617
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -32,11 +33,20 @@ public class RerankExec extends InferenceExec {
3233

3334
private final Expression queryText;
3435
private final List<Alias> rerankFields;
35-
36-
public RerankExec(Source source, PhysicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
36+
private final Attribute scoreAttribute;
37+
38+
public RerankExec(
39+
Source source,
40+
PhysicalPlan child,
41+
Expression inferenceId,
42+
Expression queryText,
43+
List<Alias> rerankFields,
44+
Attribute scoreAttribute
45+
) {
3746
super(source, child, inferenceId);
3847
this.queryText = queryText;
3948
this.rerankFields = rerankFields;
49+
this.scoreAttribute = scoreAttribute;
4050
}
4151

4252
public RerankExec(StreamInput in) throws IOException {
@@ -45,7 +55,8 @@ public RerankExec(StreamInput in) throws IOException {
4555
in.readNamedWriteable(PhysicalPlan.class),
4656
in.readNamedWriteable(Expression.class),
4757
in.readNamedWriteable(Expression.class),
48-
in.readCollectionAsList(Alias::new)
58+
in.readCollectionAsList(Alias::new),
59+
in.readNamedWriteable(Attribute.class)
4960
);
5061
}
5162

@@ -57,6 +68,10 @@ public List<Alias> rerankFields() {
5768
return rerankFields;
5869
}
5970

71+
public Attribute scoreAttribute() {
72+
return scoreAttribute;
73+
}
74+
6075
@Override
6176
public String getWriteableName() {
6277
return ENTRY.name;
@@ -67,16 +82,17 @@ public void writeTo(StreamOutput out) throws IOException {
6782
super.writeTo(out);
6883
out.writeNamedWriteable(queryText());
6984
out.writeCollection(rerankFields());
85+
out.writeNamedWriteable(scoreAttribute);
7086
}
7187

7288
@Override
7389
protected NodeInfo<? extends PhysicalPlan> info() {
74-
return NodeInfo.create(this, RerankExec::new, child(), inferenceId(), queryText, rerankFields);
90+
return NodeInfo.create(this, RerankExec::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
7591
}
7692

7793
@Override
7894
public UnaryExec replaceChild(PhysicalPlan newChild) {
79-
return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields);
95+
return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
8096
}
8197

8298
@Override
@@ -85,11 +101,13 @@ public boolean equals(Object o) {
85101
if (o == null || getClass() != o.getClass()) return false;
86102
if (super.equals(o) == false) return false;
87103
RerankExec rerank = (RerankExec) o;
88-
return Objects.equals(queryText, rerank.queryText) && Objects.equals(rerankFields, rerank.rerankFields);
104+
return Objects.equals(queryText, rerank.queryText)
105+
&& Objects.equals(rerankFields, rerank.rerankFields)
106+
&& Objects.equals(scoreAttribute, rerank.scoreAttribute);
89107
}
90108

91109
@Override
92110
public int hashCode() {
93-
return Objects.hash(super.hashCode(), queryText, rerankFields);
111+
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
94112
}
95113
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -568,14 +568,7 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon
568568

569569
String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx));
570570
String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx));
571-
572-
int scoreChannel = -1;
573-
574-
for (Attribute attr : rerank.output()) {
575-
if (attr.name().equals(MetadataAttribute.SCORE)) {
576-
scoreChannel = source.layout.get(attr.id()).channel();
577-
}
578-
}
571+
int scoreChannel = source.layout.get(rerank.scoreAttribute().id()).channel();
579572

580573
return source.with(
581574
new RerankOperator.Factory(inferenceService, inferenceId, queryText, rerankFieldsEvaluatorSuppliers, scoreChannel),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
8787
}
8888

8989
if (p instanceof Rerank rerank) {
90-
return new RerankExec(rerank.source(), child, rerank.inferenceId(), rerank.queryText(), rerank.rerankFields());
90+
return new RerankExec(
91+
rerank.source(),
92+
child,
93+
rerank.inferenceId(),
94+
rerank.queryText(),
95+
rerank.rerankFields(),
96+
rerank.scoreAttribute()
97+
);
9198
}
9299

93100
if (p instanceof Enrich enrich) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3280,7 +3280,7 @@ private void assertEmptyEsRelation(LogicalPlan plan) {
32803280
public void testResolveRerankInferenceId() {
32813281
{
32823282
LogicalPlan plan = analyze(
3283-
"FROM books | RERANK \"italian food recipe\" ON title WITH \"reranking-inference-id\"",
3283+
" FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH \"reranking-inference-id\"",
32843284
"mapping-books.json"
32853285
);
32863286
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
@@ -3290,7 +3290,10 @@ public void testResolveRerankInferenceId() {
32903290
{
32913291
VerificationException ve = expectThrows(
32923292
VerificationException.class,
3293-
() -> analyze("FROM books | RERANK \"italian food recipe\" ON title WITH \"completion-inference-id\"", "mapping-books.json")
3293+
() -> analyze(
3294+
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH \"completion-inference-id\"",
3295+
"mapping-books.json"
3296+
)
32943297

32953298
);
32963299
assertThat(
@@ -3305,7 +3308,10 @@ public void testResolveRerankInferenceId() {
33053308
{
33063309
VerificationException ve = expectThrows(
33073310
VerificationException.class,
3308-
() -> analyze("FROM books | RERANK \"italian food recipe\" ON title WITH \"error-inference-id\"", "mapping-books.json")
3311+
() -> analyze(
3312+
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH \"error-inference-id\"",
3313+
"mapping-books.json"
3314+
)
33093315

33103316
);
33113317
assertThat(ve.getMessage(), containsString("error with inference resolution"));
@@ -3314,10 +3320,13 @@ public void testResolveRerankInferenceId() {
33143320
{
33153321
VerificationException ve = expectThrows(
33163322
VerificationException.class,
3317-
() -> analyze("FROM books | RERANK \"italian food recipe\" ON title WITH \"unknow-inference-id\"", "mapping-books.json")
3323+
() -> analyze(
3324+
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH \"unknown-inference-id\"",
3325+
"mapping-books.json"
3326+
)
33183327

33193328
);
3320-
assertThat(ve.getMessage(), containsString("unresolved inference [unknow-inference-id]"));
3329+
assertThat(ve.getMessage(), containsString("unresolved inference [unknown-inference-id]"));
33213330
}
33223331
}
33233332

@@ -3327,9 +3336,9 @@ public void testResolveRerankFields() {
33273336
{
33283337
// Single field.
33293338
LogicalPlan plan = analyze("""
3330-
FROM books
3339+
FROM books METADATA _score
33313340
| WHERE title:"italian food recipe" OR description:"italian food recipe"
3332-
| KEEP description, title, year
3341+
| KEEP description, title, year, _score
33333342
| DROP description
33343343
| RERANK "italian food recipe" ON title WITH "reranking-inference-id"
33353344
""", "mapping-books.json");
@@ -3352,7 +3361,7 @@ public void testResolveRerankFields() {
33523361
{
33533362
// Multiple fields.
33543363
LogicalPlan plan = analyze("""
3355-
FROM books
3364+
FROM books METADATA _score
33563365
| WHERE title:"food"
33573366
| RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year WITH "reranking-inference-id"
33583367
""", "mapping-books.json");
@@ -3392,7 +3401,7 @@ public void testResolveRerankFields() {
33923401
VerificationException ve = expectThrows(
33933402
VerificationException.class,
33943403
() -> analyze(
3395-
"FROM books | RERANK \"italian food recipe\" ON missingField WITH \"reranking-inference-id\"",
3404+
"FROM books METADATA _score | RERANK \"italian food recipe\" ON missingField WITH \"reranking-inference-id\"",
33963405
"mapping-books.json"
33973406
)
33983407

@@ -3401,6 +3410,16 @@ public void testResolveRerankFields() {
34013410
}
34023411
}
34033412

3413+
public void testRerankRequiresScore() {
3414+
3415+
VerificationException ve = expectThrows(
3416+
VerificationException.class,
3417+
() -> analyze("FROM books | RERANK \"italian food recipe\" ON title WITH \"reranking-inference-id\"", "mapping-books.json")
3418+
3419+
);
3420+
assertThat(ve.getMessage(), containsString("Missing required column [_score] for RERANK"));
3421+
}
3422+
34043423
@Override
34053424
protected IndexAnalyzers createDefaultIndexAnalyzers() {
34063425
return super.createDefaultIndexAnalyzers();

0 commit comments

Comments
 (0)