Skip to content

Commit 6734c6f

Browse files
committed
Change the way inferenceId is passed to the command.
1 parent 1fdb090 commit 6734c6f

File tree

9 files changed

+137
-58
lines changed

9 files changed

+137
-58
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ required_capability: match_operator_colon
99

1010
FROM books METADATA _score
1111
| WHERE title:"war and peace" AND author:"Tolstoy"
12-
| RERANK "war and peace" ON title WITH test_reranker
12+
| RERANK "war and peace" ON title WITH inferenceId=test_reranker
1313
| KEEP book_no, title, author
1414
;
1515

@@ -27,7 +27,7 @@ required_capability: match_operator_colon
2727

2828
FROM books METADATA _score
2929
| WHERE title:"war and peace" AND author:"Tolstoy"
30-
| RERANK "war and peace" ON title, author WITH test_reranker
30+
| RERANK "war and peace" ON title, author inferenceId=test_reranker
3131
| KEEP book_no, title, author
3232
;
3333

@@ -47,7 +47,7 @@ FROM books METADATA _score
4747
| WHERE title:"war and peace" AND author:"Tolstoy"
4848
| SORT _score DESC
4949
| LIMIT 3
50-
| RERANK "war and peace" ON title WITH test_reranker
50+
| RERANK "war and peace" ON title inferenceId=test_reranker
5151
| KEEP book_no, title, author
5252
;
5353

@@ -64,7 +64,7 @@ required_capability: match_operator_colon
6464

6565
FROM books METADATA _score
6666
| WHERE title:"war and peace" AND author:"Tolstoy"
67-
| RERANK "war and peace" ON title WITH test_reranker
67+
| RERANK "war and peace" ON title inferenceId=test_reranker
6868
| KEEP book_no, title, author
6969
| LIMIT 3
7070
;
@@ -82,7 +82,7 @@ required_capability: match_operator_colon
8282

8383
FROM books
8484
| WHERE title:"war and peace" AND author:"Tolstoy"
85-
| RERANK "war and peace" ON title WITH test_reranker
85+
| RERANK "war and peace" ON title inferenceId=test_reranker
8686
| KEEP book_no, title, author
8787
| SORT author, title
8888
| LIMIT 3
@@ -105,7 +105,7 @@ FROM books METADATA _id, _index, _score
105105
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
106106
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
107107
| RRF
108-
| RERANK "Tolkien" ON title WITH test_reranker
108+
| RERANK "Tolkien" ON title inferenceId=test_reranker
109109
| LIMIT 2
110110
| KEEP book_no, title, author
111111
;

x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ rrfCommand
308308
;
309309

310310
rerankCommand
311-
: DEV_RERANK queryText=constant ON rerankFields (WITH inferenceId=identifierOrParameter)?
311+
: DEV_RERANK queryText=constant ON rerankFields (WITH commandOptions)?
312312
;
313313

314314
completionCommand

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,11 @@ private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput)
835835
if (rerank.scoreAttribute() instanceof UnresolvedAttribute ua) {
836836
Attribute resolved = resolveAttribute(ua, childrenOutput);
837837
if (resolved.resolved() == false || resolved.dataType() != DOUBLE) {
838-
resolved = MetadataAttribute.create(Source.EMPTY, MetadataAttribute.SCORE);
838+
if (ua.name().equals(MetadataAttribute.SCORE)) {
839+
resolved = MetadataAttribute.create(Source.EMPTY, MetadataAttribute.SCORE);
840+
} else {
841+
resolved = new ReferenceAttribute(resolved.source(), resolved.name(), DOUBLE);
842+
}
839843
}
840844
rerank = rerank.withScoreAttribute(resolved);
841845
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java

Lines changed: 4 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,9 @@ public PlanFactory visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) {
723723
@Override
724724
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
725725
Source source = source(ctx);
726+
List<Alias> rerankFields = visitRerankFields(ctx.rerankFields());
726727
Expression queryText = expression(ctx.queryText);
728+
727729
if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
728730
if (queryTextLiteral.value() == null) {
729731
throw new ParsingException(
@@ -740,46 +742,82 @@ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
740742
);
741743
}
742744

743-
Literal inferenceId = ctx.inferenceId != null
744-
? inferenceId(ctx.inferenceId)
745-
: new Literal(source, Rerank.DEFAULT_INFERENCE_ID, KEYWORD);
745+
return p -> visitRerankOptions(new Rerank.Builder(source, p, queryText, rerankFields), ctx.commandOptions()).build();
746+
}
747+
748+
private Rerank.Builder visitRerankOptions(Rerank.Builder rerannkBuilder, EsqlBaseParser.CommandOptionsContext ctx) {
749+
if (ctx == null) {
750+
return rerannkBuilder;
751+
}
746752

747-
return p -> new Rerank(source, p, inferenceId, queryText, visitRerankFields(ctx.rerankFields()));
753+
for (var option : ctx.commandOption()) {
754+
String optionName = visitIdentifier(option.identifier());
755+
if (optionName.equals(Rerank.Builder.INFERENCE_ID_OPTION_NAME)) {
756+
rerannkBuilder.withInferenceId(visitInferenceId(expression(option.primaryExpression())));
757+
} else if (optionName.equals(Rerank.Builder.SCORE_COLUMN_OPTION_NAME)) {
758+
if (expression(option.primaryExpression()) instanceof UnresolvedAttribute scoreAttribute) {
759+
rerannkBuilder.withScoreColumnAttribute(scoreAttribute);
760+
} else {
761+
throw new ParsingException(
762+
source(option.identifier()),
763+
"Option [{}] expects a valid attribute in RERANK command. [{}] provided.",
764+
option.identifier().getText(),
765+
option.primaryExpression().getText()
766+
);
767+
}
768+
} else {
769+
throw new ParsingException(
770+
source(option.identifier()),
771+
"Unknow parameter [{}] in RERANK command",
772+
option.identifier().getText()
773+
);
774+
}
775+
}
776+
777+
return rerannkBuilder;
748778
}
749779

750780
@Override
751781
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
752782
Source source = source(ctx);
753783
Expression prompt = expression(ctx.prompt);
754-
Literal inferenceId = inferenceId(ctx.inferenceId);
784+
Literal inferenceId = visitInferenceId(ctx.inferenceId);
755785
Attribute targetField = ctx.targetField == null
756786
? new UnresolvedAttribute(source, Completion.DEFAULT_OUTPUT_FIELD_NAME)
757787
: visitQualifiedName(ctx.targetField);
758788

759789
return p -> new Completion(source, p, inferenceId, prompt, targetField);
760790
}
761791

762-
public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
792+
public Literal visitInferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
763793
if (ctx.identifier() != null) {
764794
return new Literal(source(ctx), visitIdentifier(ctx.identifier()), KEYWORD);
765795
}
766796

767-
if (expression(ctx.parameter()) instanceof Literal literalParam) {
768-
if (literalParam.value() != null) {
769-
return literalParam;
797+
return visitInferenceId(expression(ctx.parameter()));
798+
}
799+
800+
public Literal visitInferenceId(Expression expression) {
801+
if (expression instanceof Literal literal) {
802+
if (literal.value() == null) {
803+
throw new ParsingException(
804+
expression.source(),
805+
"Query parameter [{}] is null or undefined and cannot be used as inference id",
806+
expression.source().text()
807+
);
770808
}
771809

772-
throw new ParsingException(
773-
source(ctx.parameter()),
774-
"Query parameter [{}] is null or undefined and cannot be used as inference id",
775-
ctx.parameter().getText()
776-
);
810+
return literal;
811+
} else if (expression instanceof UnresolvedAttribute attribute) {
812+
// Support for unquoted inference id
813+
return new Literal(expression.source(), attribute.name(), KEYWORD);
777814
}
778815

779816
throw new ParsingException(
780-
source(ctx.parameter()),
781-
"Query parameter [{}] is not a string and cannot be used as inference id",
782-
ctx.parameter().getText()
817+
expression.source(),
818+
"Query parameter [{}] is not a string and cannot be used as inference id [{}]",
819+
expression.source().text(),
820+
expression.getClass()
783821
);
784822
}
785823

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
1919
import org.elasticsearch.xpack.esql.core.expression.Expression;
2020
import org.elasticsearch.xpack.esql.core.expression.Expressions;
21+
import org.elasticsearch.xpack.esql.core.expression.Literal;
2122
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
2223
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
2324
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2425
import org.elasticsearch.xpack.esql.core.tree.Source;
26+
import org.elasticsearch.xpack.esql.core.type.DataType;
2527
import org.elasticsearch.xpack.esql.expression.Order;
2628
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2729
import org.elasticsearch.xpack.esql.plan.QueryPlan;
@@ -37,23 +39,17 @@
3739

3840
import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
3941
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
42+
import static org.elasticsearch.xpack.esql.parser.ParserUtils.source;
4043

4144
public class Rerank extends InferencePlan<Rerank> implements SortAgnostic, SurrogateLogicalPlan, TelemetryAware {
4245

4346
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
44-
public static final Object DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";
47+
public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";
4548
private final Attribute scoreAttribute;
4649
private final Expression queryText;
4750
private final List<Alias> rerankFields;
4851
private List<Attribute> lazyOutput;
4952

50-
public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
51-
super(source, child, inferenceId);
52-
this.queryText = queryText;
53-
this.rerankFields = rerankFields;
54-
this.scoreAttribute = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
55-
}
56-
5753
public Rerank(
5854
Source source,
5955
LogicalPlan child,
@@ -189,4 +185,38 @@ public List<Attribute> output() {
189185
public static boolean planHasAttribute(QueryPlan<?> plan, Attribute attribute) {
190186
return plan.outputSet().stream().anyMatch(attr -> attr.equals(attribute));
191187
}
188+
189+
public static class Builder {
190+
191+
public static final String INFERENCE_ID_OPTION_NAME = "inferenceId";
192+
public static final String SCORE_COLUMN_OPTION_NAME = "scoreColumn";
193+
194+
private final Source source;
195+
private final LogicalPlan child;
196+
private final Expression queryText;
197+
private final List<Alias> rerankFields;
198+
private Expression inferenceId = new Literal(Source.EMPTY, Rerank.DEFAULT_INFERENCE_ID, DataType.KEYWORD);
199+
private Attribute scoreAttribute = new UnresolvedAttribute(Source.EMPTY, MetadataAttribute.SCORE);
200+
201+
public Builder(Source source, LogicalPlan child, Expression queryText, List<Alias> rerankFields) {
202+
this.source = source;
203+
this.child = child;
204+
this.queryText = queryText;
205+
this.rerankFields = rerankFields;
206+
}
207+
208+
public Builder withInferenceId(Expression inferenceId) {
209+
this.inferenceId = inferenceId;
210+
return this;
211+
}
212+
213+
public Builder withScoreColumnAttribute(Attribute scoreAttribute) {
214+
this.scoreAttribute = scoreAttribute;
215+
return this;
216+
}
217+
218+
public Rerank build() {
219+
return new Rerank(source, child, inferenceId, queryText, rerankFields, scoreAttribute);
220+
}
221+
}
192222
}

0 commit comments

Comments
 (0)