Skip to content

Commit 69d551f

Browse files
committed
Adding the _score column automatically if missing from previous step.
1 parent 32dbb2c commit 69d551f

File tree

7 files changed

+129
-22
lines changed

7 files changed

+129
-22
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,22 @@ book_no:keyword | title:text | author:text
7474
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222222276031971
7575
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083333395421505
7676
;
77+
78+
79+
reranker add the _score column when missing
80+
required_capability: rerank
81+
required_capability: match_function
82+
83+
FROM books
84+
| WHERE title:"war and peace" AND author:"Tolstoy"
85+
| RERANK "war and peace" ON title WITH "test_reranker"
86+
| KEEP book_no, title, author, _score
87+
;
88+
89+
90+
book_no:keyword | title:text | author:text | _score:double
91+
5327 | War and Peace | Leo Tolstoy | 0.03846153989434242
92+
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222222276031971
93+
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083333395421505
94+
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515151560306549
95+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.analysis;
99

1010
import org.elasticsearch.common.logging.HeaderWarning;
11+
import org.elasticsearch.common.logging.LoggerMessageFormat;
1112
import org.elasticsearch.compute.data.Block;
1213
import org.elasticsearch.core.Strings;
1314
import org.elasticsearch.index.IndexMode;
@@ -127,7 +128,6 @@
127128

128129
import static java.util.Collections.emptyList;
129130
import static java.util.Collections.singletonList;
130-
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
131131
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE;
132132
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
133133
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
@@ -752,17 +752,12 @@ private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput)
752752
rerank = rerank.withRerankFields(newFields);
753753
}
754754

755-
// Ensure the score attribute is resolved
755+
// Ensure the score attribute is present in the output.
756756
if (rerank.scoreAttribute() instanceof UnresolvedAttribute ua) {
757757
Attribute resolved = resolveAttribute(ua, childrenOutput);
758-
if (resolved.resolved() == false) {
759-
resolved = ua.withUnresolvedMessage(format(null, "Missing required column [{}] for RERANK", MetadataAttribute.SCORE));
760-
} else if (resolved.dataType() != DOUBLE) {
761-
resolved = ua.withUnresolvedMessage(
762-
format("_score has the wrong type; [{}] expected but got [{}]", DataType.DOUBLE, resolved.dataType())
763-
);
758+
if (resolved.resolved() == false || resolved.dataType() != DOUBLE) {
759+
resolved = MetadataAttribute.create(Source.EMPTY, MetadataAttribute.SCORE);
764760
}
765-
766761
rerank = rerank.withScoreAttribute(resolved);
767762
}
768763

@@ -1081,7 +1076,7 @@ public static List<NamedExpression> projectionsForRename(Rename rename, List<Att
10811076
var u = resolved;
10821077
var previousAliasName = reverseAliasing.get(resolved.name());
10831078
if (previousAliasName != null) {
1084-
String message = format(
1079+
String message = LoggerMessageFormat.format(
10851080
null,
10861081
"Column [{}] renamed to [{}] and is no longer available [{}]",
10871082
resolved.name(),
@@ -1491,7 +1486,7 @@ private static boolean supportsStringImplicitCasting(DataType type) {
14911486
}
14921487

14931488
private static UnresolvedAttribute unresolvedAttribute(Expression value, String type, Exception e) {
1494-
String message = format(
1489+
String message = LoggerMessageFormat.format(
14951490
"Cannot convert string [{}] to [{}], error [{}]",
14961491
value.fold(FoldContext.small() /* TODO remove me */),
14971492
type,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ private Page buildOutput(Page inputPage, InferenceAction.Response inferenceRespo
156156
}
157157

158158
private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
159-
int blockCount = inputPage.getBlockCount();
159+
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
160160
Block[] blocks = new Block[blockCount];
161161

162162
try {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.xpack.esql.core.tree.Source;
2424
import org.elasticsearch.xpack.esql.expression.Order;
2525
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
26+
import org.elasticsearch.xpack.esql.plan.QueryPlan;
2627
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2728
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
2829
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
@@ -34,13 +35,15 @@
3435
import java.util.Objects;
3536

3637
import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
38+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
3739

3840
public class Rerank extends InferencePlan implements SortAgnostic, SurrogateLogicalPlan {
3941

4042
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
4143
private final Attribute scoreAttribute;
4244
private final Expression queryText;
4345
private final List<Alias> rerankFields;
46+
private List<Attribute> lazyOutput;
4447

4548
public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
4649
super(source, child, inferenceId);
@@ -126,7 +129,11 @@ public UnaryPlan replaceChild(LogicalPlan newChild) {
126129
@Override
127130
protected AttributeSet computeReferences() {
128131
AttributeSet refs = computeReferences(rerankFields);
129-
refs.add(scoreAttribute);
132+
133+
if (planHasAttribute(child(), scoreAttribute)) {
134+
refs.add(scoreAttribute);
135+
}
136+
130137
return refs;
131138
}
132139

@@ -166,4 +173,19 @@ public LogicalPlan surrogate() {
166173
Order sortOrder = new Order(source(), scoreAttribute, Order.OrderDirection.DESC, Order.NullsPosition.ANY);
167174
return new OrderBy(source(), this, List.of(sortOrder));
168175
}
176+
177+
@Override
178+
public List<Attribute> output() {
179+
if (lazyOutput == null) {
180+
lazyOutput = planHasAttribute(child(), scoreAttribute)
181+
? child().output()
182+
: mergeOutputAttributes(List.of(scoreAttribute), child().output());
183+
}
184+
185+
return lazyOutput;
186+
}
187+
188+
public static boolean planHasAttribute(QueryPlan<?> plan, Attribute attribute) {
189+
return plan.output().stream().anyMatch(attr -> attr.equals(attribute));
190+
}
169191
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.xpack.esql.core.expression.Alias;
1414
import org.elasticsearch.xpack.esql.core.expression.Attribute;
15+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
1516
import org.elasticsearch.xpack.esql.core.expression.Expression;
1617
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1718
import org.elasticsearch.xpack.esql.core.tree.Source;
1819
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
20+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
1921
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
2022
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
2123

2224
import java.io.IOException;
2325
import java.util.List;
2426
import java.util.Objects;
2527

28+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
29+
import static org.elasticsearch.xpack.esql.plan.logical.inference.Rerank.planHasAttribute;
30+
2631
public class RerankExec extends InferenceExec {
2732

2833
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
@@ -95,6 +100,26 @@ public UnaryExec replaceChild(PhysicalPlan newChild) {
95100
return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
96101
}
97102

103+
@Override
104+
public List<Attribute> output() {
105+
if (planHasAttribute(child(), scoreAttribute)) {
106+
return child().output();
107+
}
108+
109+
return mergeOutputAttributes(List.of(scoreAttribute), child().output());
110+
}
111+
112+
@Override
113+
protected AttributeSet computeReferences() {
114+
AttributeSet refs = Rerank.computeReferences(rerankFields);
115+
116+
if (planHasAttribute(child(), scoreAttribute)) {
117+
refs.add(scoreAttribute);
118+
}
119+
120+
return refs;
121+
}
122+
98123
@Override
99124
public boolean equals(Object o) {
100125
if (this == o) return true;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,11 +568,18 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon
568568

569569
String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx));
570570
String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx));
571-
int scoreChannel = source.layout.get(rerank.scoreAttribute().id()).channel();
571+
572+
Layout.Builder layoutBuilder = source.layout.builder();
573+
if (source.layout.get(rerank.scoreAttribute().id()) == null) {
574+
layoutBuilder.append(rerank.scoreAttribute());
575+
}
576+
Layout outputLayout = layoutBuilder.build();
577+
578+
int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
572579

573580
return source.with(
574581
new RerankOperator.Factory(inferenceService, inferenceId, queryText, rerankFieldsEvaluatorSuppliers, scoreChannel),
575-
source.layout
582+
outputLayout
576583
);
577584
}
578585

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3358,6 +3358,10 @@ public void testResolveRerankFields() {
33583358
assertThat(rerank.queryText(), equalTo(string("italian food recipe")));
33593359
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
33603360
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute))));
3361+
assertThat(
3362+
rerank.scoreAttribute(),
3363+
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
3364+
);
33613365
}
33623366

33633367
{
@@ -3397,6 +3401,10 @@ public void testResolveRerankFields() {
33973401
Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get();
33983402
assertThat(yearAttribute, notNullValue());
33993403
assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute)));
3404+
assertThat(
3405+
rerank.scoreAttribute(),
3406+
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
3407+
);
34003408
}
34013409

34023410
{
@@ -3412,15 +3420,46 @@ public void testResolveRerankFields() {
34123420
}
34133421
}
34143422

3415-
public void testRerankRequiresScore() {
3416-
assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());
3423+
public void testResolveRerankScoreFields() {
3424+
{
3425+
// When the metadata field is required in FROM, it is reused.
3426+
LogicalPlan plan = analyze("""
3427+
FROM books METADATA _score
3428+
| WHERE title:"italian food recipe" OR description:"italian food recipe"
3429+
| RERANK "italian food recipe" ON title WITH "reranking-inference-id"
3430+
""", "mapping-books.json");
34173431

3418-
VerificationException ve = expectThrows(
3419-
VerificationException.class,
3420-
() -> analyze("FROM books | RERANK \"italian food recipe\" ON title WITH \"reranking-inference-id\"", "mapping-books.json")
3432+
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
3433+
Rerank rerank = as(limit.child(), Rerank.class);
3434+
Filter filter = as(rerank.child(), Filter.class);
3435+
EsRelation relation = as(filter.child(), EsRelation.class);
34213436

3422-
);
3423-
assertThat(ve.getMessage(), containsString("Missing required column [_score] for RERANK"));
3437+
Attribute metadataScoreAttribute = relation.output()
3438+
.stream()
3439+
.filter(attr -> attr.name().equals(MetadataAttribute.SCORE))
3440+
.findFirst()
3441+
.get();
3442+
assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute));
3443+
assertThat(rerank.output(), hasItem(metadataScoreAttribute));
3444+
}
3445+
3446+
{
3447+
// When the metadata field is not required in FROM, it is added to the output of RERANK
3448+
LogicalPlan plan = analyze("""
3449+
FROM books
3450+
| WHERE title:"italian food recipe" OR description:"italian food recipe"
3451+
| RERANK "italian food recipe" ON title WITH "reranking-inference-id"
3452+
""", "mapping-books.json");
3453+
3454+
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
3455+
Rerank rerank = as(limit.child(), Rerank.class);
3456+
Filter filter = as(rerank.child(), Filter.class);
3457+
EsRelation relation = as(filter.child(), EsRelation.class);
3458+
3459+
assertThat(relation.output().stream().noneMatch(attr -> attr.name().equals(MetadataAttribute.SCORE)), is(true));
3460+
assertThat(rerank.scoreAttribute(), equalTo(MetadataAttribute.create(EMPTY, MetadataAttribute.SCORE)));
3461+
assertThat(rerank.output(), hasItem(rerank.scoreAttribute()));
3462+
}
34243463
}
34253464

34263465
@Override

0 commit comments

Comments
 (0)