Skip to content

Commit 3e88a16

Browse files
author
afoucret
committed
Adding parser tests for inference limits.
1 parent ef57a6a commit 3e88a16

File tree

4 files changed

+83
-3
lines changed

4 files changed

+83
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
/**
1414
* Settings for inference features such as completion and rerank.
1515
*/
16-
public record InferenceSettings(boolean completionEnabled, int completionMaxSize, boolean rerankEnabled, int rerankMaxSize) {
16+
public record InferenceSettings(boolean completionEnabled, int completionRowLimit, boolean rerankEnabled, int rerankRowLimit) {
1717

1818
public static final Setting<Boolean> COMPLETION_ENABLED_SETTING = Setting.boolSetting(
1919
"esql.command.completion.enabled",

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,13 +1112,18 @@ private MapExpression visitFuseOptions(List<EsqlBaseParser.FuseConfigurationCont
11121112
@Override
11131113
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
11141114
Source source = source(ctx);
1115+
1116+
if (context.inferenceSettings().rerankEnabled() == false) {
1117+
throw new ParsingException(source, "RERANK command is disabled in settings.");
1118+
}
1119+
11151120
List<Alias> rerankFields = visitRerankFields(ctx.rerankFields());
11161121
Expression queryText = expression(ctx.queryText);
11171122
Attribute scoreAttribute = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, MetadataAttribute.SCORE));
11181123
if (scoreAttribute.qualifier() != null) {
11191124
throw qualifiersUnsupportedInFieldDefinitions(scoreAttribute.source(), ctx.targetField.getText());
11201125
}
1121-
Literal rowLimit = Literal.integer(source, 1000);
1126+
Literal rowLimit = Literal.integer(source, context.inferenceSettings().rerankRowLimit());
11221127

11231128
return p -> {
11241129
checkForRemoteClusters(p, source, "RERANK");
@@ -1157,14 +1162,19 @@ private Rerank applyRerankOptions(Rerank rerank, EsqlBaseParser.CommandNamedPara
11571162

11581163
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
11591164
Source source = source(ctx);
1165+
1166+
if (context.inferenceSettings().completionEnabled() == false) {
1167+
throw new ParsingException(source, "COMPLETION command is disabled in settings.");
1168+
}
1169+
11601170
Expression prompt = expression(ctx.prompt);
11611171
Attribute targetField = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, Completion.DEFAULT_OUTPUT_FIELD_NAME));
11621172

11631173
if (targetField.qualifier() != null) {
11641174
throw qualifiersUnsupportedInFieldDefinitions(targetField.source(), ctx.targetField.getText());
11651175
}
11661176

1167-
Literal rowLimit = Literal.integer(source, 100);
1177+
Literal rowLimit = Literal.integer(source, context.inferenceSettings().completionRowLimit());
11681178

11691179
return p -> {
11701180
checkForRemoteClusters(p, source, "COMPLETION");

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.common.logging.LoggerMessageFormat;
1111
import org.elasticsearch.common.lucene.BytesRefs;
12+
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.index.IndexMode;
1314
import org.elasticsearch.test.ESTestCase;
1415
import org.elasticsearch.xpack.esql.VerificationException;
@@ -17,11 +18,14 @@
1718
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
1819
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
1920
import org.elasticsearch.xpack.esql.core.type.DataType;
21+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
2022
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
23+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
2124
import org.elasticsearch.xpack.esql.plan.EsqlStatement;
2225
import org.elasticsearch.xpack.esql.plan.IndexPattern;
2326
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2427
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
28+
import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry;
2529

2630
import java.math.BigInteger;
2731
import java.util.ArrayList;
@@ -69,6 +73,10 @@ LogicalPlan processingCommand(String e) {
6973
return parser.parseQuery("row a = 1 | " + e);
7074
}
7175

76+
LogicalPlan processingCommand(String e, QueryParams params, Settings settings) {
77+
return parser.createStatement("row a = 1 | " + e, params, new PlanTelemetry(new EsqlFunctionRegistry()), new InferenceSettings(settings));
78+
}
79+
7280
static UnresolvedAttribute attribute(String name) {
7381
return new UnresolvedAttribute(EMPTY, name);
7482
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.Build;
1212
import org.elasticsearch.common.logging.LoggerMessageFormat;
1313
import org.elasticsearch.common.lucene.BytesRefs;
14+
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.core.PathUtils;
1516
import org.elasticsearch.core.Tuple;
1617
import org.elasticsearch.index.IndexMode;
@@ -48,6 +49,7 @@
4849
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
4950
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
5051
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
52+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
5153
import org.elasticsearch.xpack.esql.plan.IndexPattern;
5254
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
5355
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
@@ -3643,6 +3645,7 @@ public void testRerankDefaultInferenceIdAndScoreAttribute() {
36433645
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
36443646
assertThat(rerank.queryText(), equalTo(literalString("statement text")));
36453647
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3648+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36463649
}
36473650

36483651
public void testRerankEmptyOptions() {
@@ -3653,6 +3656,7 @@ public void testRerankEmptyOptions() {
36533656
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
36543657
assertThat(rerank.queryText(), equalTo(literalString("statement text")));
36553658
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3659+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36563660
}
36573661

36583662
public void testRerankInferenceId() {
@@ -3663,6 +3667,7 @@ public void testRerankInferenceId() {
36633667
assertThat(rerank.queryText(), equalTo(literalString("statement text")));
36643668
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
36653669
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3670+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36663671
}
36673672

36683673
public void testRerankScoreAttribute() {
@@ -3673,6 +3678,7 @@ public void testRerankScoreAttribute() {
36733678
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
36743679
assertThat(rerank.queryText(), equalTo(literalString("statement text")));
36753680
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3681+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36763682
}
36773683

36783684
public void testRerankInferenceIdAnddScoreAttribute() {
@@ -3683,6 +3689,7 @@ public void testRerankInferenceIdAnddScoreAttribute() {
36833689
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
36843690
assertThat(rerank.queryText(), equalTo(literalString("statement text")));
36853691
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3692+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36863693
}
36873694

36883695
public void testRerankSingleField() {
@@ -3693,6 +3700,7 @@ public void testRerankSingleField() {
36933700
assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
36943701
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
36953702
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3703+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36963704
}
36973705

36983706
public void testRerankMultipleFields() {
@@ -3714,6 +3722,7 @@ public void testRerankMultipleFields() {
37143722
)
37153723
);
37163724
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3725+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37173726
}
37183727

37193728
public void testRerankComputedFields() {
@@ -3734,6 +3743,7 @@ public void testRerankComputedFields() {
37343743
)
37353744
);
37363745
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3746+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37373747
}
37383748

37393749
public void testRerankComputedFieldsWithoutName() {
@@ -3755,6 +3765,7 @@ public void testRerankWithPositionalParameters() {
37553765
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
37563766
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
37573767
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
3768+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37583769
}
37593770

37603771
public void testRerankWithNamedParameters() {
@@ -3770,6 +3781,29 @@ public void testRerankWithNamedParameters() {
37703781
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
37713782
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
37723783
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
3784+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
3785+
}
3786+
3787+
public void testRerankRowLimitOverride() {
3788+
int customRowLimit = between(1, 10_000);
3789+
Settings settings = Settings.builder().put(InferenceSettings.RERANK_ROW_LIMIT_SETTING.getKey(), customRowLimit).build();
3790+
3791+
var plan = as(
3792+
processingCommand("RERANK \"query text\" ON title WITH { \"inference_id\" : \"inferenceID\" }", new QueryParams(), settings),
3793+
Rerank.class
3794+
);
3795+
3796+
assertThat(plan.rowLimit(), equalTo(Literal.integer(EMPTY, customRowLimit)));
3797+
}
3798+
3799+
public void testRerankCommandDisabled() {
3800+
Settings settings = Settings.builder().put(InferenceSettings.RERANK_ENABLED_SETTING.getKey(), false).build();
3801+
3802+
ParsingException pe = expectThrows(
3803+
ParsingException.class,
3804+
() -> processingCommand("RERANK \"query text\" ON title", new QueryParams(), settings)
3805+
);
3806+
assertThat(pe.getMessage(), containsString("RERANK command is disabled"));
37733807
}
37743808

37753809
public void testInvalidRerank() {
@@ -3816,6 +3850,8 @@ public void testCompletionUsingFieldAsPrompt() {
38163850
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38173851
assertThat(plan.inferenceId(), equalTo(literalString("inferenceID")));
38183852
assertThat(plan.targetField(), equalToIgnoringIds(attribute("targetField")));
3853+
assertThat(plan.rowLimit(), equalTo(integer(100)));
3854+
38193855
}
38203856

38213857
public void testCompletionUsingFunctionAsPrompt() {
@@ -3827,6 +3863,7 @@ public void testCompletionUsingFunctionAsPrompt() {
38273863
assertThat(plan.prompt(), equalToIgnoringIds(function("CONCAT", List.of(attribute("fieldA"), attribute("fieldB")))));
38283864
assertThat(plan.inferenceId(), equalTo(literalString("inferenceID")));
38293865
assertThat(plan.targetField(), equalToIgnoringIds(attribute("targetField")));
3866+
assertThat(plan.rowLimit(), equalTo(integer(100)));
38303867
}
38313868

38323869
public void testCompletionDefaultFieldName() {
@@ -3835,6 +3872,7 @@ public void testCompletionDefaultFieldName() {
38353872
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38363873
assertThat(plan.inferenceId(), equalTo(literalString("inferenceID")));
38373874
assertThat(plan.targetField(), equalToIgnoringIds(attribute("completion")));
3875+
assertThat(plan.rowLimit(), equalTo(integer(100)));
38383876
}
38393877

38403878
public void testCompletionWithPositionalParameters() {
@@ -3847,6 +3885,7 @@ public void testCompletionWithPositionalParameters() {
38473885
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38483886
assertThat(plan.inferenceId(), equalTo(literalString("inferenceId")));
38493887
assertThat(plan.targetField(), equalToIgnoringIds(attribute("completion")));
3888+
assertThat(plan.rowLimit(), equalTo(integer(100)));
38503889
}
38513890

38523891
public void testCompletionWithNamedParameters() {
@@ -3859,6 +3898,29 @@ public void testCompletionWithNamedParameters() {
38593898
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38603899
assertThat(plan.inferenceId(), equalTo(literalString("myInference")));
38613900
assertThat(plan.targetField(), equalToIgnoringIds(attribute("completion")));
3901+
assertThat(plan.rowLimit(), equalTo(integer(100)));
3902+
}
3903+
3904+
public void testCompletionRowLimitOverride() {
3905+
int customRowLimit = between(1, 10_000);
3906+
Settings settings = Settings.builder().put(InferenceSettings.COMPLETION_ROW_LIMIT_SETTING.getKey(), customRowLimit).build();
3907+
3908+
var plan = as(
3909+
processingCommand("COMPLETION prompt_field WITH{ \"inference_id\" : \"inferenceID\" }", new QueryParams(), settings),
3910+
Completion.class
3911+
);
3912+
3913+
assertThat(plan.rowLimit(), equalTo(Literal.integer(EMPTY, customRowLimit)));
3914+
}
3915+
3916+
public void testCompletionCommandDisabled() {
3917+
Settings settings = Settings.builder().put(InferenceSettings.COMPLETION_ENABLED_SETTING.getKey(), false).build();
3918+
3919+
ParsingException pe = expectThrows(
3920+
ParsingException.class,
3921+
() -> processingCommand("COMPLETION prompt_field WITH{ \"inference_id\" : \"inferenceID\" }", new QueryParams(), settings)
3922+
);
3923+
assertThat(pe.getMessage(), containsString("COMPLETION command is disabled"));
38623924
}
38633925

38643926
public void testInvalidCompletion() {

0 commit comments

Comments
 (0)