Skip to content

Commit 6818567

Browse files
committed
Rerank command logical plan parsing
1 parent b2c0e11 commit 6818567

File tree

5 files changed

+165
-1
lines changed

5 files changed

+165
-1
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,12 @@ public enum Cap {
834834
/**
835835
* Support for FORK command
836836
*/
837-
FORK(Build.current().isSnapshot());
837+
FORK(Build.current().isSnapshot()),
838+
839+
/**
840+
* Support for RERANK command
841+
*/
842+
RERANK(Build.current().isSnapshot());
838843

839844
private final boolean enabled;
840845

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,30 @@ public Literal visitIntegerValue(EsqlBaseParser.IntegerValueContext ctx) {
194194
return new Literal(source, val, type);
195195
}
196196

197+
@Override
198+
public String visitStringOrParameter(EsqlBaseParser.StringOrParameterContext ctx) {
199+
if (ctx.parameter() != null) {
200+
if (expression(ctx.parameter()) instanceof Literal lit) {
201+
if (lit.value() == null) {
202+
throw new ParsingException(
203+
source(ctx.parameter()),
204+
"Query parameter [{}] is null or undefined and cannot be used as string",
205+
ctx.parameter().getText()
206+
);
207+
}
208+
return lit.value().toString();
209+
}
210+
211+
throw new ParsingException(
212+
source(ctx.parameter()),
213+
"Query parameter [{}], cannot be used as string",
214+
ctx.parameter().getText()
215+
);
216+
}
217+
218+
return unquote(source(ctx.string()));
219+
}
220+
197221
@Override
198222
public Object visitNumericArrayLiteral(EsqlBaseParser.NumericArrayLiteralContext ctx) {
199223
Source source = source(ctx);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.elasticsearch.xpack.esql.plan.logical.Rename;
6363
import org.elasticsearch.xpack.esql.plan.logical.Row;
6464
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
65+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
6566
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
6667
import org.elasticsearch.xpack.esql.plan.logical.join.StubRelation;
6768
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
@@ -686,4 +687,27 @@ public PlanFactory visitCompositeForkSubQuery(EsqlBaseParser.CompositeForkSubQue
686687
PlanFactory makePlan = typedParsing(this, ctx.forkSubQueryProcessingCommand(), PlanFactory.class);
687688
return input -> makePlan.apply(lowerPlan.apply(input));
688689
}
690+
691+
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
692+
var source = source(ctx);
693+
694+
if (false == EsqlCapabilities.Cap.RERANK.isEnabled()) {
695+
throw new ParsingException(source, "RERANK is in preview and only available in SNAPSHOT build");
696+
}
697+
698+
return p -> {
699+
List<Alias> rerankFields = new ArrayList<>();
700+
if (ctx.fields() != null) {
701+
rerankFields = visitFields(ctx.fields());
702+
} else {
703+
for (var attribute : p.output()) {
704+
rerankFields.add(new Alias(Source.EMPTY, attribute.name(), new UnresolvedAttribute(Source.EMPTY, attribute.name())));
705+
}
706+
}
707+
708+
System.out.println(rerankFields);
709+
710+
return new Rerank(source, p, visitStringOrParameter(ctx.inferenceId), visitStringOrParameter(ctx.queryText), rerankFields);
711+
};
712+
}
689713
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public void testDevelopmentMatch() throws Exception {
3030
parse("row a = 1 | match foo", "match");
3131
}
3232

33+
public void testDevelopmentRerank() {
34+
parse("row a = 1 | rerank \"foo\" ON title WITH reranker", "rerank");
35+
}
36+
3337
void parse(String query, String errorMessage) {
3438
ParsingException pe = expectThrows(ParsingException.class, () -> parser().createStatement(query));
3539
assertThat(pe.getMessage(), containsString("mismatched input '" + errorMessage + "'"));

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.elasticsearch.xpack.esql.plan.logical.Rename;
6565
import org.elasticsearch.xpack.esql.plan.logical.Row;
6666
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
67+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
6768
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
6869
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
6970

@@ -3088,6 +3089,112 @@ public void testInvalidFork() {
30883089
expectError("FROM foo* | FORK ( LIMIT 10 ) ( y+2 )", "line 1:33: mismatched input 'y' expecting {'limit', 'sort', 'where'}");
30893090
}
30903091

3092+
public void testRerankSingleField() {
3093+
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3094+
3095+
var plan = processingCommand("RERANK \"query text\" ON title WITH \"inferenceID\"");
3096+
var rerank = as(plan, Rerank.class);
3097+
3098+
assertThat(rerank.queryText(), equalTo("query text"));
3099+
assertThat(rerank.inferenceId(), equalTo("inferenceID"));
3100+
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
3101+
}
3102+
3103+
public void testRerankMultipleFields() {
3104+
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3105+
3106+
var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH \"inferenceID\"");
3107+
var rerank = as(plan, Rerank.class);
3108+
3109+
assertThat(rerank.queryText(), equalTo("query text"));
3110+
assertThat(rerank.inferenceId(), equalTo("inferenceID"));
3111+
assertThat(
3112+
rerank.rerankFields(),
3113+
equalTo(
3114+
List.of(
3115+
alias("title", attribute("title")),
3116+
alias("description", attribute("description")),
3117+
alias("authors_renamed", attribute("authors"))
3118+
)
3119+
)
3120+
);
3121+
}
3122+
3123+
public void testRerankComputedFields() {
3124+
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3125+
3126+
var plan = processingCommand(
3127+
"RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH \"inferenceID\""
3128+
);
3129+
var rerank = as(plan, Rerank.class);
3130+
3131+
assertThat(rerank.queryText(), equalTo("query text"));
3132+
assertThat(rerank.inferenceId(), equalTo("inferenceID"));
3133+
assertThat(
3134+
rerank.rerankFields(),
3135+
equalTo(
3136+
List.of(
3137+
alias("title", attribute("title")),
3138+
alias(
3139+
"short_description",
3140+
function(
3141+
"SUBSTRING",
3142+
List.of(attribute("description"), new Literal(EMPTY, 0, INTEGER), new Literal(EMPTY, 100, INTEGER))
3143+
)
3144+
)
3145+
)
3146+
)
3147+
);
3148+
}
3149+
3150+
public void testRerankNoFieldSpecified() {
3151+
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3152+
3153+
var plan = parser.createStatement("row a = 1, b = \"foo\", c=true | RERANK \"query text\" WITH \"inferenceID\"");
3154+
var rerank = as(plan, Rerank.class);
3155+
3156+
assertThat(rerank.queryText(), equalTo("query text"));
3157+
assertThat(rerank.inferenceId(), equalTo("inferenceID"));
3158+
3159+
// When no field are specified, all the fields are used.
3160+
assertThat(rerank.rerankFields(), contains(alias("a", attribute("a")), alias("b", attribute("b")), alias("c", attribute("c"))));
3161+
}
3162+
3163+
public void testRerankWithPositionalParameters() {
3164+
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3165+
3166+
var queryParams = new QueryParams(List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker")));
3167+
var rerank = as(parser.createStatement("row a = 1 | RERANK ? ON title WITH ?", queryParams), Rerank.class);
3168+
3169+
assertThat(rerank.queryText(), equalTo("query text"));
3170+
assertThat(rerank.inferenceId(), equalTo("reranker"));
3171+
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
3172+
}
3173+
3174+
public void testRerankWithNamedParameters() {
3175+
assumeTrue("FORK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3176+
3177+
var queryParams = new QueryParams(List.of(paramAsConstant("queryText", "query text"), paramAsConstant("inferenceId", "reranker")));
3178+
var rerank = as(parser.createStatement("row a = 1 | RERANK ?queryText ON title WITH ?inferenceId", queryParams), Rerank.class);
3179+
3180+
assertThat(rerank.queryText(), equalTo("query text"));
3181+
assertThat(rerank.inferenceId(), equalTo("reranker"));
3182+
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
3183+
}
3184+
3185+
public void testInvalidRerank() {
3186+
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
3187+
expectError(
3188+
"FROM foo* | RERANK ON title WITH inferenceId",
3189+
"line 1:20: mismatched input 'ON' expecting {QUOTED_STRING, '?', NAMED_OR_POSITIONAL_PARAM}"
3190+
);
3191+
3192+
expectError(
3193+
"FROM foo* | RERANK \"query text\" ON title",
3194+
"line 1:42: mismatched input '<EOF>' expecting {'with', 'and', '::', ',', '.', 'or', '+', '-', '*', '/', '%'}"
3195+
);
3196+
}
3197+
30913198
static Alias alias(String name, Expression value) {
30923199
return new Alias(EMPTY, name, value);
30933200
}

0 commit comments

Comments
 (0)