Skip to content

Commit 5ce0d87

Browse files
committed
Adding parser tests for inference limits.
1 parent 83ec2b3 commit 5ce0d87

File tree

4 files changed

+83
-3
lines changed

4 files changed

+83
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
/**
1414
* Settings for inference features such as completion and rerank.
1515
*/
16-
public record InferenceSettings(boolean completionEnabled, int completionMaxSize, boolean rerankEnabled, int rerankMaxSize) {
16+
public record InferenceSettings(boolean completionEnabled, int completionRowLimit, boolean rerankEnabled, int rerankRowLimit) {
1717

1818
public static final Setting<Boolean> COMPLETION_ENABLED_SETTING = Setting.boolSetting(
1919
"esql.command.completion.enabled",

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,13 +1112,18 @@ private MapExpression visitFuseOptions(List<EsqlBaseParser.FuseConfigurationCont
11121112
@Override
11131113
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
11141114
Source source = source(ctx);
1115+
1116+
if (context.inferenceSettings().rerankEnabled() == false) {
1117+
throw new ParsingException(source, "RERANK command is disabled in settings.");
1118+
}
1119+
11151120
List<Alias> rerankFields = visitRerankFields(ctx.rerankFields());
11161121
Expression queryText = expression(ctx.queryText);
11171122
Attribute scoreAttribute = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, MetadataAttribute.SCORE));
11181123
if (scoreAttribute.qualifier() != null) {
11191124
throw qualifiersUnsupportedInFieldDefinitions(scoreAttribute.source(), ctx.targetField.getText());
11201125
}
1121-
Literal rowLimit = Literal.integer(source, 1000);
1126+
Literal rowLimit = Literal.integer(source, context.inferenceSettings().rerankRowLimit());
11221127

11231128
return p -> {
11241129
checkForRemoteClusters(p, source, "RERANK");
@@ -1157,14 +1162,19 @@ private Rerank applyRerankOptions(Rerank rerank, EsqlBaseParser.CommandNamedPara
11571162

11581163
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
11591164
Source source = source(ctx);
1165+
1166+
if (context.inferenceSettings().completionEnabled() == false) {
1167+
throw new ParsingException(source, "COMPLETION command is disabled in settings.");
1168+
}
1169+
11601170
Expression prompt = expression(ctx.prompt);
11611171
Attribute targetField = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, Completion.DEFAULT_OUTPUT_FIELD_NAME));
11621172

11631173
if (targetField.qualifier() != null) {
11641174
throw qualifiersUnsupportedInFieldDefinitions(targetField.source(), ctx.targetField.getText());
11651175
}
11661176

1167-
Literal rowLimit = Literal.integer(source, 100);
1177+
Literal rowLimit = Literal.integer(source, context.inferenceSettings().completionRowLimit());
11681178

11691179
return p -> {
11701180
checkForRemoteClusters(p, source, "COMPLETION");

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.common.logging.LoggerMessageFormat;
1111
import org.elasticsearch.common.lucene.BytesRefs;
12+
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.index.IndexMode;
1314
import org.elasticsearch.test.ESTestCase;
1415
import org.elasticsearch.xpack.esql.VerificationException;
@@ -17,11 +18,14 @@
1718
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
1819
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
1920
import org.elasticsearch.xpack.esql.core.type.DataType;
21+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
2022
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
23+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
2124
import org.elasticsearch.xpack.esql.plan.EsqlStatement;
2225
import org.elasticsearch.xpack.esql.plan.IndexPattern;
2326
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2427
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
28+
import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry;
2529

2630
import java.math.BigInteger;
2731
import java.util.ArrayList;
@@ -69,6 +73,10 @@ LogicalPlan processingCommand(String e) {
6973
return parser.createStatement("row a = 1 | " + e);
7074
}
7175

76+
LogicalPlan processingCommand(String e, QueryParams params, Settings settings) {
77+
return parser.createStatement("row a = 1 | " + e, params, new PlanTelemetry(new EsqlFunctionRegistry()), new InferenceSettings(settings));
78+
}
79+
7280
static UnresolvedAttribute attribute(String name) {
7381
return new UnresolvedAttribute(EMPTY, name);
7482
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.Build;
1212
import org.elasticsearch.common.logging.LoggerMessageFormat;
1313
import org.elasticsearch.common.lucene.BytesRefs;
14+
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.core.PathUtils;
1516
import org.elasticsearch.core.Tuple;
1617
import org.elasticsearch.index.IndexMode;
@@ -48,6 +49,7 @@
4849
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
4950
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
5051
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
52+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
5153
import org.elasticsearch.xpack.esql.plan.IndexPattern;
5254
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
5355
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
@@ -3650,6 +3652,7 @@ public void testRerankDefaultInferenceIdAndScoreAttribute() {
36503652
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
36513653
assertThat(rerank.queryText(), equalTo(literalString("query text")));
36523654
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3655+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36533656
}
36543657

36553658
public void testRerankEmptyOptions() {
@@ -3660,6 +3663,7 @@ public void testRerankEmptyOptions() {
36603663
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
36613664
assertThat(rerank.queryText(), equalTo(literalString("query text")));
36623665
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3666+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36633667
}
36643668

36653669
public void testRerankInferenceId() {
@@ -3670,6 +3674,7 @@ public void testRerankInferenceId() {
36703674
assertThat(rerank.queryText(), equalTo(literalString("query text")));
36713675
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
36723676
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3677+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36733678
}
36743679

36753680
public void testRerankScoreAttribute() {
@@ -3680,6 +3685,7 @@ public void testRerankScoreAttribute() {
36803685
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
36813686
assertThat(rerank.queryText(), equalTo(literalString("query text")));
36823687
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3688+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36833689
}
36843690

36853691
public void testRerankInferenceIdAnddScoreAttribute() {
@@ -3690,6 +3696,7 @@ public void testRerankInferenceIdAnddScoreAttribute() {
36903696
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
36913697
assertThat(rerank.queryText(), equalTo(literalString("query text")));
36923698
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
3699+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
36933700
}
36943701

36953702
public void testRerankSingleField() {
@@ -3700,6 +3707,7 @@ public void testRerankSingleField() {
37003707
assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
37013708
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
37023709
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3710+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37033711
}
37043712

37053713
public void testRerankMultipleFields() {
@@ -3721,6 +3729,7 @@ public void testRerankMultipleFields() {
37213729
)
37223730
);
37233731
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3732+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37243733
}
37253734

37263735
public void testRerankComputedFields() {
@@ -3741,6 +3750,7 @@ public void testRerankComputedFields() {
37413750
)
37423751
);
37433752
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("_score")));
3753+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37443754
}
37453755

37463756
public void testRerankComputedFieldsWithoutName() {
@@ -3762,6 +3772,7 @@ public void testRerankWithPositionalParameters() {
37623772
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
37633773
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
37643774
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
3775+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
37653776
}
37663777

37673778
public void testRerankWithNamedParameters() {
@@ -3778,6 +3789,29 @@ public void testRerankWithNamedParameters() {
37783789
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
37793790
assertThat(rerank.rerankFields(), equalToIgnoringIds(List.of(alias("title", attribute("title")))));
37803791
assertThat(rerank.scoreAttribute(), equalToIgnoringIds(attribute("rerank_score")));
3792+
assertThat(rerank.rowLimit(), equalTo(integer(1_000)));
3793+
}
3794+
3795+
public void testRerankRowLimitOverride() {
3796+
int customRowLimit = between(1, 10_000);
3797+
Settings settings = Settings.builder().put(InferenceSettings.RERANK_ROW_LIMIT_SETTING.getKey(), customRowLimit).build();
3798+
3799+
var plan = as(
3800+
processingCommand("RERANK \"query text\" ON title WITH { \"inference_id\" : \"inferenceID\" }", new QueryParams(), settings),
3801+
Rerank.class
3802+
);
3803+
3804+
assertThat(plan.rowLimit(), equalTo(Literal.integer(EMPTY, customRowLimit)));
3805+
}
3806+
3807+
public void testRerankCommandDisabled() {
3808+
Settings settings = Settings.builder().put(InferenceSettings.RERANK_ENABLED_SETTING.getKey(), false).build();
3809+
3810+
ParsingException pe = expectThrows(
3811+
ParsingException.class,
3812+
() -> processingCommand("RERANK \"query text\" ON title", new QueryParams(), settings)
3813+
);
3814+
assertThat(pe.getMessage(), containsString("RERANK command is disabled"));
37813815
}
37823816

37833817
public void testInvalidRerank() {
@@ -3824,6 +3858,8 @@ public void testCompletionUsingFieldAsPrompt() {
38243858
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38253859
assertThat(plan.inferenceId(), equalTo(literalString("inferenceID")));
38263860
assertThat(plan.targetField(), equalToIgnoringIds(attribute("targetField")));
3861+
assertThat(plan.rowLimit(), equalTo(integer(100)));
3862+
38273863
}
38283864

38293865
public void testCompletionUsingFunctionAsPrompt() {
@@ -3835,6 +3871,7 @@ public void testCompletionUsingFunctionAsPrompt() {
38353871
assertThat(plan.prompt(), equalToIgnoringIds(function("CONCAT", List.of(attribute("fieldA"), attribute("fieldB")))));
38363872
assertThat(plan.inferenceId(), equalTo(literalString("inferenceID")));
38373873
assertThat(plan.targetField(), equalToIgnoringIds(attribute("targetField")));
3874+
assertThat(plan.rowLimit(), equalTo(integer(100)));
38383875
}
38393876

38403877
public void testCompletionDefaultFieldName() {
@@ -3843,6 +3880,7 @@ public void testCompletionDefaultFieldName() {
38433880
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38443881
assertThat(plan.inferenceId(), equalTo(literalString("inferenceID")));
38453882
assertThat(plan.targetField(), equalToIgnoringIds(attribute("completion")));
3883+
assertThat(plan.rowLimit(), equalTo(integer(100)));
38463884
}
38473885

38483886
public void testCompletionWithPositionalParameters() {
@@ -3855,6 +3893,7 @@ public void testCompletionWithPositionalParameters() {
38553893
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38563894
assertThat(plan.inferenceId(), equalTo(literalString("inferenceId")));
38573895
assertThat(plan.targetField(), equalToIgnoringIds(attribute("completion")));
3896+
assertThat(plan.rowLimit(), equalTo(integer(100)));
38583897
}
38593898

38603899
public void testCompletionWithNamedParameters() {
@@ -3867,6 +3906,29 @@ public void testCompletionWithNamedParameters() {
38673906
assertThat(plan.prompt(), equalToIgnoringIds(attribute("prompt_field")));
38683907
assertThat(plan.inferenceId(), equalTo(literalString("myInference")));
38693908
assertThat(plan.targetField(), equalToIgnoringIds(attribute("completion")));
3909+
assertThat(plan.rowLimit(), equalTo(integer(100)));
3910+
}
3911+
3912+
public void testCompletionRowLimitOverride() {
3913+
int customRowLimit = between(1, 10_000);
3914+
Settings settings = Settings.builder().put(InferenceSettings.COMPLETION_ROW_LIMIT_SETTING.getKey(), customRowLimit).build();
3915+
3916+
var plan = as(
3917+
processingCommand("COMPLETION prompt_field WITH{ \"inference_id\" : \"inferenceID\" }", new QueryParams(), settings),
3918+
Completion.class
3919+
);
3920+
3921+
assertThat(plan.rowLimit(), equalTo(Literal.integer(EMPTY, customRowLimit)));
3922+
}
3923+
3924+
public void testCompletionCommandDisabled() {
3925+
Settings settings = Settings.builder().put(InferenceSettings.COMPLETION_ENABLED_SETTING.getKey(), false).build();
3926+
3927+
ParsingException pe = expectThrows(
3928+
ParsingException.class,
3929+
() -> processingCommand("COMPLETION prompt_field WITH{ \"inference_id\" : \"inferenceID\" }", new QueryParams(), settings)
3930+
);
3931+
assertThat(pe.getMessage(), containsString("COMPLETION command is disabled"));
38703932
}
38713933

38723934
public void testInvalidCompletion() {

0 commit comments

Comments
 (0)