Skip to content

Commit e3e9792

Browse files
author
afoucret
committed
Read completion and rerank settings during parsing.
1 parent 56c1a1e commit e3e9792

File tree

15 files changed

+267
-28
lines changed

15 files changed

+267
-28
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
565565
mock(ProjectResolver.class),
566566
mock(IndexNameExpressionResolver.class),
567567
null,
568-
new InferenceService(mock(Client.class)),
568+
new InferenceService(mock(Client.class), Settings.EMPTY),
569569
new BlockFactoryProvider(PlannerUtils.NON_BREAKING_BLOCK_FACTORY),
570570
TEST_PLANNER_SETTINGS,
571571
new CrossProjectModeDecider(Settings.EMPTY)

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.client.internal.Client;
11+
import org.elasticsearch.common.settings.Settings;
1112
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
1213
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
1314
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
@@ -17,18 +18,25 @@ public class InferenceService {
1718

1819
private final BulkInferenceRunner.Factory bulkInferenceRunnerFactory;
1920

21+
private final InferenceSettings inferenceSettings;
22+
2023
/**
2124
* Creates a new inference service with the given client.
2225
*
2326
* @param client the Elasticsearch client for inference operations
2427
*/
25-
public InferenceService(Client client) {
26-
this(InferenceResolver.factory(client), BulkInferenceRunner.factory(client));
28+
public InferenceService(Client client, Settings settings) {
29+
this(InferenceResolver.factory(client), BulkInferenceRunner.factory(client), settings);
2730
}
2831

29-
private InferenceService(InferenceResolver.Factory inferenceResolverFactory, BulkInferenceRunner.Factory bulkInferenceRunnerFactory) {
32+
private InferenceService(
33+
InferenceResolver.Factory inferenceResolverFactory,
34+
BulkInferenceRunner.Factory bulkInferenceRunnerFactory,
35+
Settings settings
36+
) {
3037
this.inferenceResolverFactory = inferenceResolverFactory;
3138
this.bulkInferenceRunnerFactory = bulkInferenceRunnerFactory;
39+
this.inferenceSettings = InferenceSettings.fromSettings(settings);
3240
}
3341

3442
/**
@@ -42,6 +50,15 @@ public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry
4250
return inferenceResolverFactory.create(functionRegistry);
4351
}
4452

53+
/**
54+
* Returns the inference configuration settings.
55+
*
56+
* @return the inference settings
57+
*/
58+
public InferenceSettings inferenceSettings() {
59+
return inferenceSettings;
60+
}
61+
4562
public BulkInferenceRunner bulkInferenceRunner() {
4663
return bulkInferenceRunner(BulkInferenceRunnerConfig.DEFAULT);
4764
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.common.settings.Setting;
11+
import org.elasticsearch.common.settings.Settings;
12+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
13+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
14+
15+
import java.util.Locale;
16+
import java.util.Map;
17+
import java.util.Objects;
18+
19+
public class InferenceSettings {
20+
21+
public static final Setting<Boolean> COMPLETION_ENABLED_SETTING = CommandSettings.commandEnabledSetting("completion");
22+
public static final Setting<Integer> COMPLETION_ROW_LIMIT_SETTING = CommandSettings.rowLimitSetting(
23+
"completion",
24+
Completion.DEFAULT_MAX_ROW_LIMIT
25+
);
26+
27+
public static final Setting<Boolean> RERANK_ENABLED_SETTING = CommandSettings.commandEnabledSetting("rerank");
28+
public static final Setting<Integer> RERANK_ROW_LIMIT_SETTING = CommandSettings.rowLimitSetting("rerank", Rerank.DEFAULT_MAX_ROW_LIMIT);
29+
30+
public static CommandSettings completionCommandConfig(Settings settings) {
31+
return new CommandSettings(COMPLETION_ENABLED_SETTING.get(settings), COMPLETION_ROW_LIMIT_SETTING.get(settings));
32+
}
33+
34+
public static CommandSettings rerankCommandConfig(Settings settings) {
35+
return new CommandSettings(RERANK_ENABLED_SETTING.get(settings), RERANK_ROW_LIMIT_SETTING.get(settings));
36+
}
37+
38+
private final Map<String, CommandSettings> commandSettings;
39+
40+
public static InferenceSettings fromSettings(Settings settings) {
41+
return new InferenceSettings(CommandSettings.fromSettings(settings));
42+
}
43+
44+
private InferenceSettings(Map<String, CommandSettings> commandSettings) {
45+
this.commandSettings = commandSettings;
46+
}
47+
48+
public CommandSettings commandSettings(String commandName) {
49+
return commandSettings.get(commandName);
50+
}
51+
52+
@Override
53+
public boolean equals(Object o) {
54+
if (this == o) return true;
55+
if (o == null || getClass() != o.getClass()) return false;
56+
InferenceSettings that = (InferenceSettings) o;
57+
return Objects.equals(commandSettings, that.commandSettings);
58+
}
59+
60+
@Override
61+
public int hashCode() {
62+
return Objects.hash(commandSettings);
63+
}
64+
65+
public record CommandSettings(boolean enabled, int rowLimit) {
66+
67+
private static final String ENABLED_SETTING_PATTERN = "inference.command.%s.enabled";
68+
private static final String ROW_LIMIT_PATTERN = "inference.command.%s.row_limit";
69+
70+
public static Map<String, CommandSettings> fromSettings(Settings settings) {
71+
return Map.ofEntries(
72+
Map.entry("completion", completionCommandConfig(settings)),
73+
Map.entry("rerank", rerankCommandConfig(settings))
74+
);
75+
}
76+
77+
private static Setting<Boolean> commandEnabledSetting(String commandName) {
78+
return Setting.boolSetting(
79+
String.format(Locale.ROOT, ENABLED_SETTING_PATTERN, commandName),
80+
true,
81+
Setting.Property.NodeScope,
82+
Setting.Property.Dynamic
83+
);
84+
}
85+
86+
private static Setting<Integer> rowLimitSetting(String commandName, int defaultValue) {
87+
return Setting.intSetting(
88+
String.format(Locale.ROOT, ROW_LIMIT_PATTERN, commandName),
89+
defaultValue,
90+
-1,
91+
Setting.Property.NodeScope,
92+
Setting.Property.Dynamic
93+
);
94+
}
95+
}
96+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
import org.antlr.v4.runtime.TokenSource;
1717
import org.antlr.v4.runtime.VocabularyImpl;
1818
import org.antlr.v4.runtime.atn.PredictionMode;
19+
import org.elasticsearch.common.settings.Settings;
1920
import org.elasticsearch.logging.LogManager;
2021
import org.elasticsearch.logging.Logger;
2122
import org.elasticsearch.xpack.esql.core.util.StringUtils;
2223
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
24+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
2325
import org.elasticsearch.xpack.esql.plan.EsqlStatement;
2426
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2527
import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry;
@@ -105,14 +107,19 @@ public LogicalPlan createStatement(String query) {
105107

106108
// testing utility
107109
public LogicalPlan createStatement(String query, QueryParams params) {
108-
return createStatement(query, params, new PlanTelemetry(new EsqlFunctionRegistry()));
110+
return createStatement(
111+
query,
112+
params,
113+
new PlanTelemetry(new EsqlFunctionRegistry()),
114+
InferenceSettings.fromSettings(Settings.EMPTY)
115+
);
109116
}
110117

111-
public LogicalPlan createStatement(String query, QueryParams params, PlanTelemetry metrics) {
118+
public LogicalPlan createStatement(String query, QueryParams params, PlanTelemetry metrics, InferenceSettings inferenceSettings) {
112119
if (log.isDebugEnabled()) {
113120
log.debug("Parsing as statement: {}", query);
114121
}
115-
return invokeParser(query, params, metrics, EsqlBaseParser::singleStatement, AstBuilder::plan);
122+
return invokeParser(query, params, metrics, inferenceSettings, EsqlBaseParser::singleStatement, AstBuilder::plan);
116123
}
117124

118125
// testing utility
@@ -122,20 +129,21 @@ public EsqlStatement createQuery(String query) {
122129

123130
// testing utility
124131
public EsqlStatement createQuery(String query, QueryParams params) {
125-
return createQuery(query, params, new PlanTelemetry(new EsqlFunctionRegistry()));
132+
return createQuery(query, params, new PlanTelemetry(new EsqlFunctionRegistry()), InferenceSettings.fromSettings(Settings.EMPTY));
126133
}
127134

128-
public EsqlStatement createQuery(String query, QueryParams params, PlanTelemetry metrics) {
135+
public EsqlStatement createQuery(String query, QueryParams params, PlanTelemetry metrics, InferenceSettings inferenceSettings) {
129136
if (log.isDebugEnabled()) {
130137
log.debug("Parsing as statement: {}", query);
131138
}
132-
return invokeParser(query, params, metrics, EsqlBaseParser::statements, AstBuilder::statement);
139+
return invokeParser(query, params, metrics, inferenceSettings, EsqlBaseParser::statements, AstBuilder::statement);
133140
}
134141

135142
private <T> T invokeParser(
136143
String query,
137144
QueryParams params,
138145
PlanTelemetry metrics,
146+
InferenceSettings inferenceSettings,
139147
Function<EsqlBaseParser, ParserRuleContext> parseFunction,
140148
BiFunction<AstBuilder, ParserRuleContext, T> result
141149
) {
@@ -169,7 +177,7 @@ private <T> T invokeParser(
169177
log.trace("Parse tree: {}", tree.toStringTree());
170178
}
171179

172-
return result.apply(new AstBuilder(new ExpressionBuilder.ParsingContext(params, metrics)), tree);
180+
return result.apply(new AstBuilder(new ExpressionBuilder.ParsingContext(params, metrics, inferenceSettings)), tree);
173181
} catch (StackOverflowError e) {
174182
throw new ParsingException("ESQL statement is too large, causing stack overflow when generating the parsing tree: [{}]", query);
175183
// likely thrown by an invalid popMode (such as extra closing parenthesis)

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InsensitiveEquals;
6868
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
6969
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
70+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
7071
import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry;
7172
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;
7273

@@ -124,7 +125,7 @@ public abstract class ExpressionBuilder extends IdentifierBuilder {
124125

125126
protected final ParsingContext context;
126127

127-
public record ParsingContext(QueryParams params, PlanTelemetry telemetry) {}
128+
public record ParsingContext(QueryParams params, PlanTelemetry telemetry, InferenceSettings inferenceSettings) {}
128129

129130
ExpressionBuilder(ParsingContext context) {
130131
this.context = context;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
5656
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
5757
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
58+
import org.elasticsearch.xpack.esql.inference.InferenceSettings;
5859
import org.elasticsearch.xpack.esql.parser.promql.PromqlParserUtils;
5960
import org.elasticsearch.xpack.esql.plan.EsqlStatement;
6061
import org.elasticsearch.xpack.esql.plan.IndexPattern;
@@ -1112,6 +1113,12 @@ private MapExpression visitFuseOptions(List<EsqlBaseParser.FuseConfigurationCont
11121113
@Override
11131114
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
11141115
Source source = source(ctx);
1116+
1117+
InferenceSettings.CommandSettings commandSettings = inferenceCommandSettings("rerank");
1118+
if (commandSettings.enabled() == false) {
1119+
throw new ParsingException(source(ctx), "RERANK command is disabled. Enable it in the inference settings to use it.");
1120+
}
1121+
11151122
List<Alias> rerankFields = visitRerankFields(ctx.rerankFields());
11161123
Expression queryText = expression(ctx.queryText);
11171124
Attribute scoreAttribute = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, MetadataAttribute.SCORE));
@@ -1121,7 +1128,8 @@ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
11211128

11221129
return p -> {
11231130
checkForRemoteClusters(p, source, "RERANK");
1124-
return applyRerankOptions(new Rerank(source, p, queryText, rerankFields, scoreAttribute), ctx.commandNamedParameters());
1131+
return applyRerankOptions(new Rerank(source, p, queryText, rerankFields, scoreAttribute), ctx.commandNamedParameters())
1132+
.withMaxRows(commandSettings.rowLimit());
11251133
};
11261134
}
11271135

@@ -1153,19 +1161,32 @@ private Rerank applyRerankOptions(Rerank rerank, EsqlBaseParser.CommandNamedPara
11531161

11541162
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
11551163
Source source = source(ctx);
1164+
1165+
InferenceSettings.CommandSettings commandSettings = inferenceCommandSettings("completion");
1166+
if (commandSettings.enabled() == false) {
1167+
throw new ParsingException(source(ctx), "COMPLETION command is disabled. Enable it in the inference settings to use it.");
1168+
}
1169+
11561170
Expression prompt = expression(ctx.prompt);
11571171
Attribute targetField = visitQualifiedName(ctx.targetField, new UnresolvedAttribute(source, Completion.DEFAULT_OUTPUT_FIELD_NAME));
11581172

11591173
if (targetField.qualifier() != null) {
11601174
throw qualifiersUnsupportedInFieldDefinitions(targetField.source(), ctx.targetField.getText());
11611175
}
11621176

1177+
11631178
return p -> {
11641179
checkForRemoteClusters(p, source, "COMPLETION");
1165-
return applyCompletionOptions(new Completion(source, p, prompt, targetField), ctx.commandNamedParameters());
1180+
return applyCompletionOptions(new Completion(source, p, prompt, targetField), ctx.commandNamedParameters()).withMaxRows(
1181+
commandSettings.rowLimit()
1182+
);
11661183
};
11671184
}
11681185

1186+
private InferenceSettings.CommandSettings inferenceCommandSettings(String commandName) {
1187+
return context.inferenceSettings().commandSettings(commandName);
1188+
}
1189+
11691190
private Completion applyCompletionOptions(Completion completion, EsqlBaseParser.CommandNamedParametersContext ctx) {
11701191
MapExpression optionsExpression = ctx == null ? null : visitCommandNamedParameters(ctx);
11711192

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/RowLimited.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
* <p>
2020
* Practically it means that a LIMIT to the plan children.
2121
*/
22-
public interface RowLimited extends SurrogateLogicalPlan {
22+
public interface RowLimited<PlanType extends LogicalPlan> extends SurrogateLogicalPlan {
2323
/**
2424
* Returns the maximum number of rows this plan can produce.
2525
*/
@@ -28,12 +28,12 @@ public interface RowLimited extends SurrogateLogicalPlan {
2828
/**
2929
* Sets the maximum number of rows this plan can produce
3030
*/
31-
default RowLimited withMaxRows(int maxRows) {
31+
default PlanType withMaxRows(int maxRows) {
3232
return withMaxRows(Literal.integer(Source.EMPTY, maxRows));
3333
}
3434

3535
/**
3636
* Sets the maximum number of rows this plan can produce
3737
*/
38-
RowLimited withMaxRows(Expression maxRows);
38+
PlanType withMaxRows(Expression maxRows);
3939
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.xpack.esql.core.type.DataType;
2525
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2626
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
27-
import org.elasticsearch.xpack.esql.plan.logical.RowLimited;
2827

2928
import java.io.IOException;
3029
import java.util.List;
@@ -94,7 +93,7 @@ public Attribute targetField() {
9493
}
9594

9695
@Override
97-
public RowLimited withMaxRows(Expression rowLimit) {
96+
public Completion withMaxRows(Expression rowLimit) {
9897
return new Completion(source(), child(), inferenceId(), rowLimit, prompt, targetField);
9998
}
10099

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public abstract class InferencePlan<PlanType extends InferencePlan<PlanType>> ex
3232
SortAgnostic,
3333
GeneratingPlan<InferencePlan<PlanType>>,
3434
ExecutesOn.Coordinator,
35-
RowLimited {
35+
RowLimited<PlanType> {
3636

37-
protected static final TransportVersion ESQL_INFERENCE_USAGE_LIMIT = TransportVersion.fromName("esql_inference_usage_limit");
37+
public static final TransportVersion ESQL_INFERENCE_USAGE_LIMIT = TransportVersion.fromName("esql_inference_usage_limit");
3838

3939
public static final String INFERENCE_ID_OPTION_NAME = "inference_id";
4040
public static final List<String> VALID_INFERENCE_OPTION_NAMES = List.of(INFERENCE_ID_OPTION_NAME);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2828
import org.elasticsearch.xpack.esql.plan.logical.Eval;
2929
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
30-
import org.elasticsearch.xpack.esql.plan.logical.RowLimited;
3130
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
3231

3332
import java.io.IOException;
@@ -116,8 +115,8 @@ public TaskType taskType() {
116115
}
117116

118117
@Override
119-
public RowLimited withMaxRows(Expression rowLimit) {
120-
return new Rerank(source(), child(), inferenceId(), rowLimit(), queryText, rerankFields, scoreAttribute);
118+
public Rerank withMaxRows(Expression rowLimit) {
119+
return new Rerank(source(), child(), inferenceId(), rowLimit, queryText, rerankFields, scoreAttribute);
121120
}
122121

123122
@Override

0 commit comments

Comments
 (0)