Skip to content

Commit e89a74b

Browse files
committed
Validate rerank fields types.
1 parent 281a7d8 commit e89a74b

File tree

3 files changed

+99
-17
lines changed

3 files changed

+99
-17
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -758,18 +758,6 @@ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
758758
Expression queryText = expression(ctx.queryText);
759759
Attribute scoreAttribute = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, MetadataAttribute.SCORE));
760760

761-
if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
762-
if (queryTextLiteral.value() == null) {
763-
throw new ParsingException(source(ctx.queryText), "Query cannot be null or undefined in RERANK", ctx.queryText.getText());
764-
}
765-
} else {
766-
throw new ParsingException(
767-
source(ctx.queryText),
768-
"Query must be a valid string in RERANK, found [{}]",
769-
ctx.queryText.getText()
770-
);
771-
}
772-
773761
return p -> {
774762
checkForRemoteClusters(p, source, "RERANK");
775763
return applyRerankOptions(new Rerank(source, p, queryText, rerankFields, scoreAttribute), ctx.commandNamedParameters());

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,33 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
1415
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
16+
import org.elasticsearch.xpack.esql.common.Failures;
1517
import org.elasticsearch.xpack.esql.core.capabilities.Resolvables;
1618
import org.elasticsearch.xpack.esql.core.expression.Alias;
1719
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1820
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
1921
import org.elasticsearch.xpack.esql.core.expression.Expression;
20-
import org.elasticsearch.xpack.esql.core.expression.Expressions;
2122
import org.elasticsearch.xpack.esql.core.expression.Literal;
2223
import org.elasticsearch.xpack.esql.core.expression.NameId;
2324
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2425
import org.elasticsearch.xpack.esql.core.tree.Source;
26+
import org.elasticsearch.xpack.esql.core.type.DataType;
2527
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
28+
import org.elasticsearch.xpack.esql.plan.logical.Eval;
2629
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2730
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
2831

2932
import java.io.IOException;
3033
import java.util.List;
3134
import java.util.Objects;
35+
import java.util.function.Predicate;
3236

33-
import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
37+
import static org.elasticsearch.xpack.esql.common.Failure.fail;
3438
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
3539

36-
public class Rerank extends InferencePlan<Rerank> implements TelemetryAware {
40+
public class Rerank extends InferencePlan<Rerank> implements PostAnalysisVerificationAware, TelemetryAware {
3741

3842
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
3943
public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";
@@ -155,8 +159,7 @@ private Attribute renameScoreAttribute(String newName) {
155159
}
156160

157161
public static AttributeSet computeReferences(List<Alias> fields) {
158-
AttributeSet rerankFields = AttributeSet.of(asAttributes(fields));
159-
return Expressions.references(fields).subtract(rerankFields);
162+
return Eval.computeReferences(fields);
160163
}
161164

162165
@Override
@@ -169,6 +172,28 @@ protected NodeInfo<? extends LogicalPlan> info() {
169172
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
170173
}
171174

175+
@Override
176+
public void postAnalysisVerification(Failures failures) {
177+
if (queryText.resolved() && (DataType.isString(queryText.dataType()) == false || queryText.foldable() == false)) {
178+
// Rerank only supports string as query
179+
failures.add(fail(queryText, "query must be a valid string in RERANK, found [{}]", queryText.source().text()));
180+
}
181+
182+
// When using multiple fields the content is transformed into YAML before it is reranked
183+
// We can use any of string, numeric or boolean field.
184+
rerankFields.stream()
185+
.filter(Predicate.not(f -> DataType.isString(f.dataType()) || f.dataType() == DataType.BOOLEAN || f.dataType().isNumeric()))
186+
.forEach(
187+
rerankField -> failures.add(
188+
fail(
189+
rerankField,
190+
"rerank field must be a valid string, numeric or boolean expression, found [{}]",
191+
rerankField.source().text()
192+
)
193+
)
194+
);
195+
}
196+
172197
@Override
173198
public boolean equals(Object o) {
174199
if (this == o) return true;

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.common.settings.Settings;
1717
import org.elasticsearch.index.IndexMode;
1818
import org.elasticsearch.index.analysis.IndexAnalyzers;
19+
import org.elasticsearch.logging.LogManager;
1920
import org.elasticsearch.test.ESTestCase;
2021
import org.elasticsearch.xpack.esql.EsqlTestUtils;
2122
import org.elasticsearch.xpack.esql.LoadMapping;
@@ -3740,6 +3741,74 @@ public void testResolveRerankScoreField() {
37403741
}
37413742
}
37423743

3744+
public void testRerankInvalidQueryTypes() {
3745+
assertError("""
3746+
FROM books METADATA _score
3747+
| RERANK rerank_score = 42 ON title WITH { "inference_id" : "reranking-inference-id" }
3748+
""", "mapping-books.json", new QueryParams(), "query must be a valid string in RERANK, found [42]");
3749+
3750+
assertError("""
3751+
FROM books METADATA _score
3752+
| RERANK rerank_score = null ON title WITH { "inference_id" : "reranking-inference-id" }
3753+
""", "mapping-books.json", new QueryParams(), "query must be a valid string in RERANK, found [null]");
3754+
}
3755+
3756+
public void testRerankFieldsInvalidTypes() {
3757+
List<String> invalidFieldNames = List.of("date", "date_nanos", "ip", "version", "dense_vector");
3758+
3759+
for (String fieldName : invalidFieldNames) {
3760+
LogManager.getLogger(AnalyzerTests.class).warn("[{}]", fieldName);
3761+
assertError(
3762+
"FROM books METADATA _score | RERANK rerank_score = \"test query\" ON "
3763+
+ fieldName
3764+
+ " WITH { \"inference_id\" : \"reranking-inference-id\" }",
3765+
"mapping-all-types.json",
3766+
new QueryParams(),
3767+
"rerank field must be a valid string, numeric or boolean expression, found [" + fieldName + "]"
3768+
);
3769+
}
3770+
}
3771+
3772+
public void testRerankFieldValidTypes() {
3773+
List<String> validFieldNames = List.of(
3774+
"boolean",
3775+
"byte",
3776+
"constant_keyword-foo",
3777+
"double",
3778+
"float",
3779+
"half_float",
3780+
"scaled_float",
3781+
"integer",
3782+
"keyword",
3783+
"long",
3784+
"unsigned_long",
3785+
"short",
3786+
"text",
3787+
"wildcard"
3788+
);
3789+
3790+
for (String fieldName : validFieldNames) {
3791+
LogicalPlan plan = analyze(
3792+
"FROM books METADATA _score | RERANK rerank_score = \"test query\" ON `"
3793+
+ fieldName
3794+
+ "` WITH { \"inference_id\" : \"reranking-inference-id\" }",
3795+
"mapping-all-types.json"
3796+
);
3797+
3798+
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
3799+
EsRelation relation = as(rerank.child(), EsRelation.class);
3800+
Attribute fieldAttribute = getAttributeByName(relation.output(), fieldName);
3801+
assertThat(rerank.rerankFields(), equalTo(List.of(alias(fieldName, fieldAttribute))));
3802+
}
3803+
}
3804+
3805+
public void testInvalidValidRerankQuery() {
3806+
assertError("""
3807+
FROM books METADATA _score
3808+
| RERANK rerank_score = 42 ON title WITH { "inference_id" : "reranking-inference-id" }
3809+
""", "mapping-books.json", new QueryParams(), "query must be a valid string in RERANK, found [42]");
3810+
}
3811+
37433812
public void testResolveCompletionInferenceId() {
37443813
LogicalPlan plan = analyze("""
37453814
FROM books METADATA _score

0 commit comments

Comments
 (0)