Skip to content

Commit 3fb16f2

Browse files
committed
Add row limit to Completion and Rerank plan.
1 parent 2c80f52 commit 3fb16f2

File tree

16 files changed

+191
-45
lines changed

16 files changed

+191
-45
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
index_created_transport_version,9221000
1+
esql_completion_usage_limit,9223000

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutp
600600
prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
601601
}
602602

603-
return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
603+
return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField, p.rowLimit());
604604
}
605605

606606
private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceCommandConfig.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99

1010
import org.elasticsearch.common.settings.Setting;
1111
import org.elasticsearch.common.settings.Settings;
12+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
13+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
1214

1315
import java.util.Map;
1416

1517
public record InferenceCommandConfig(boolean enabled, int rowLimit) {
1618

1719
public static final Setting<Boolean> COMPLETION_ENABLED_SETTING = commandEnabledSetting("completion");
18-
public static final Setting<Integer> COMPLETION_ROW_LIMIT_SETTING = rowLimitSetting("completion", 100);
20+
public static final Setting<Integer> COMPLETION_ROW_LIMIT_SETTING = rowLimitSetting("completion", Completion.DEFAULT_MAX_ROW_LIMIT);
1921
public static final Setting<Boolean> RERANK_ENABLED_SETTING = commandEnabledSetting("rerank");
20-
public static final Setting<Integer> RERANK_ROW_LIMIT_SETTING = rowLimitSetting("rerank", 1000);
22+
public static final Setting<Integer> RERANK_ROW_LIMIT_SETTING = rowLimitSetting("rerank", Rerank.DEFAULT_MAX_ROW_LIMIT);
2123

2224
public static InferenceCommandConfig completionCommandConfig(Settings settings) {
2325
return new InferenceCommandConfig(COMPLETION_ENABLED_SETTING.get(settings), COMPLETION_ROW_LIMIT_SETTING.get(settings));
@@ -52,4 +54,5 @@ private static Setting<Integer> rowLimitSetting(String commandName, int defaultV
5254
Setting.Property.Dynamic
5355
);
5456
}
57+
5558
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
@FunctionalInterface
11+
public interface InferenceCommandConfigProvider {
12+
InferenceCommandConfig get(String commandName);
13+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ private <T> T invokeParser(
175175
log.trace("Parse tree: {}", tree.toStringTree());
176176
}
177177

178-
return result.apply(new AstBuilder(new ExpressionBuilder.ParsingContext(params, metrics)), tree);
178+
return result.apply(
179+
new AstBuilder(new ExpressionBuilder.ParsingContext(params, metrics, config::inferenceCommandConfig)),
180+
tree
181+
);
179182
} catch (StackOverflowError e) {
180183
throw new ParsingException("ESQL statement is too large, causing stack overflow when generating the parsing tree: [{}]", query);
181184
// likely thrown by an invalid popMode (such as extra closing parenthesis)

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InsensitiveEquals;
6868
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
6969
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
70+
import org.elasticsearch.xpack.esql.inference.InferenceCommandConfigProvider;
7071
import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry;
7172
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;
7273

@@ -124,7 +125,11 @@ public abstract class ExpressionBuilder extends IdentifierBuilder {
124125

125126
protected final ParsingContext context;
126127

127-
public record ParsingContext(QueryParams params, PlanTelemetry telemetry) {}
128+
public record ParsingContext(
129+
QueryParams params,
130+
PlanTelemetry telemetry,
131+
InferenceCommandConfigProvider inferenceCommandConfigProvider
132+
) {}
128133

129134
ExpressionBuilder(ParsingContext context) {
130135
this.context = context;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,9 +1109,14 @@ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
11091109
throw qualifiersUnsupportedInFieldDefinitions(scoreAttribute.source(), ctx.targetField.getText());
11101110
}
11111111

1112+
Literal rowLimit = Literal.integer(Source.EMPTY, context.inferenceCommandConfigProvider().get("rerank").rowLimit());
1113+
11121114
return p -> {
11131115
checkForRemoteClusters(p, source, "RERANK");
1114-
return applyRerankOptions(new Rerank(source, p, queryText, rerankFields, scoreAttribute), ctx.commandNamedParameters());
1116+
return applyRerankOptions(
1117+
new Rerank(source, p, queryText, rerankFields, scoreAttribute, rowLimit),
1118+
ctx.commandNamedParameters()
1119+
);
11151120
};
11161121
}
11171122

@@ -1150,9 +1155,11 @@ public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContex
11501155
throw qualifiersUnsupportedInFieldDefinitions(targetField.source(), ctx.targetField.getText());
11511156
}
11521157

1158+
Literal rowLimit = Literal.integer(Source.EMPTY, context.inferenceCommandConfigProvider().get("completion").rowLimit());
1159+
11531160
return p -> {
11541161
checkForRemoteClusters(p, source, "COMPLETION");
1155-
return applyCompletionOptions(new Completion(source, p, prompt, targetField), ctx.commandNamedParameters());
1162+
return applyCompletionOptions(new Completion(source, p, prompt, targetField, rowLimit), ctx.commandNamedParameters());
11561163
};
11571164
}
11581165

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.esql.plan.logical.inference;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -42,18 +43,32 @@ public class Completion extends InferencePlan<Completion> implements TelemetryAw
4243
"Completion",
4344
Completion::new
4445
);
46+
public static final int DEFAULT_MAX_ROW_LIMIT = 100;
47+
48+
private static final TransportVersion ESQL_COMPLETION_USAGE_LIMIT = TransportVersion.fromName("esql_completion_usage_limit");
49+
4550
private final Expression prompt;
4651
private final Attribute targetField;
52+
private final Expression rowLimit;
53+
4754
private List<Attribute> lazyOutput;
4855

49-
public Completion(Source source, LogicalPlan p, Expression prompt, Attribute targetField) {
50-
this(source, p, Literal.keyword(Source.EMPTY, DEFAULT_OUTPUT_FIELD_NAME), prompt, targetField);
56+
public Completion(Source source, LogicalPlan p, Expression prompt, Attribute targetField, Expression rowLimit) {
57+
this(source, p, Literal.keyword(Source.EMPTY, DEFAULT_OUTPUT_FIELD_NAME), prompt, targetField, rowLimit);
5158
}
5259

53-
public Completion(Source source, LogicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
60+
public Completion(
61+
Source source,
62+
LogicalPlan child,
63+
Expression inferenceId,
64+
Expression prompt,
65+
Attribute targetField,
66+
Expression rowLimit
67+
) {
5468
super(source, child, inferenceId);
5569
this.prompt = prompt;
5670
this.targetField = targetField;
71+
this.rowLimit = rowLimit;
5772
}
5873

5974
public Completion(StreamInput in) throws IOException {
@@ -62,7 +77,10 @@ public Completion(StreamInput in) throws IOException {
6277
in.readNamedWriteable(LogicalPlan.class),
6378
in.readNamedWriteable(Expression.class),
6479
in.readNamedWriteable(Expression.class),
65-
in.readNamedWriteable(Attribute.class)
80+
in.readNamedWriteable(Attribute.class),
81+
in.getTransportVersion().supports(ESQL_COMPLETION_USAGE_LIMIT)
82+
? in.readNamedWriteable(Expression.class)
83+
: Literal.integer(Source.EMPTY, DEFAULT_MAX_ROW_LIMIT)
6684
);
6785
}
6886

@@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException {
7189
super.writeTo(out);
7290
out.writeNamedWriteable(prompt);
7391
out.writeNamedWriteable(targetField);
92+
if (out.getTransportVersion().supports(ESQL_COMPLETION_USAGE_LIMIT)) {
93+
out.writeNamedWriteable(rowLimit);
94+
}
7495
}
7596

7697
public Expression prompt() {
@@ -81,18 +102,22 @@ public Attribute targetField() {
81102
return targetField;
82103
}
83104

105+
public Expression rowLimit() {
106+
return rowLimit;
107+
}
108+
84109
@Override
85110
public Completion withInferenceId(Expression newInferenceId) {
86111
if (inferenceId().equals(newInferenceId)) {
87112
return this;
88113
}
89114

90-
return new Completion(source(), child(), newInferenceId, prompt, targetField);
115+
return new Completion(source(), child(), newInferenceId, prompt, targetField, rowLimit);
91116
}
92117

93118
@Override
94119
public Completion replaceChild(LogicalPlan newChild) {
95-
return new Completion(source(), newChild, inferenceId(), prompt, targetField);
120+
return new Completion(source(), newChild, inferenceId(), prompt, targetField, rowLimit);
96121
}
97122

98123
@Override
@@ -122,7 +147,7 @@ public List<Attribute> generatedAttributes() {
122147
@Override
123148
public Completion withGeneratedNames(List<String> newNames) {
124149
checkNumberOfNewNames(newNames);
125-
return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0)));
150+
return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0)), rowLimit);
126151
}
127152

128153
private Attribute renameTargetField(String newName) {
@@ -157,7 +182,7 @@ public void postAnalysisVerification(Failures failures) {
157182

158183
@Override
159184
protected NodeInfo<? extends LogicalPlan> info() {
160-
return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);
185+
return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField, rowLimit);
161186
}
162187

163188
@Override
@@ -167,11 +192,13 @@ public boolean equals(Object o) {
167192
if (super.equals(o) == false) return false;
168193
Completion completion = (Completion) o;
169194

170-
return Objects.equals(prompt, completion.prompt) && Objects.equals(targetField, completion.targetField);
195+
return Objects.equals(prompt, completion.prompt)
196+
&& Objects.equals(targetField, completion.targetField)
197+
&& Objects.equals(rowLimit, completion.rowLimit);
171198
}
172199

173200
@Override
174201
public int hashCode() {
175-
return Objects.hash(super.hashCode(), prompt, targetField);
202+
return Objects.hash(super.hashCode(), prompt, targetField, rowLimit);
176203
}
177204
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.esql.plan.logical.inference;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -41,14 +42,25 @@ public class Rerank extends InferencePlan<Rerank> implements PostAnalysisVerific
4142

4243
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
4344
public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";
45+
public static final int DEFAULT_MAX_ROW_LIMIT = 1000;
46+
47+
private static final TransportVersion ESQL_RERANK_USAGE_LIMIT = TransportVersion.fromName("esql_rerank_usage_limit");
4448

4549
private final Attribute scoreAttribute;
4650
private final Expression queryText;
4751
private final List<Alias> rerankFields;
52+
private final Expression rowLimit;
4853
private List<Attribute> lazyOutput;
4954

50-
public Rerank(Source source, LogicalPlan child, Expression queryText, List<Alias> rerankFields, Attribute scoreAttribute) {
51-
this(source, child, Literal.keyword(Source.EMPTY, DEFAULT_INFERENCE_ID), queryText, rerankFields, scoreAttribute);
55+
public Rerank(
56+
Source source,
57+
LogicalPlan child,
58+
Expression queryText,
59+
List<Alias> rerankFields,
60+
Attribute scoreAttribute,
61+
Expression rowLimit
62+
) {
63+
this(source, child, Literal.keyword(Source.EMPTY, DEFAULT_INFERENCE_ID), queryText, rerankFields, scoreAttribute, rowLimit);
5264
}
5365

5466
public Rerank(
@@ -57,12 +69,14 @@ public Rerank(
5769
Expression inferenceId,
5870
Expression queryText,
5971
List<Alias> rerankFields,
60-
Attribute scoreAttribute
72+
Attribute scoreAttribute,
73+
Expression rowLimit
6174
) {
6275
super(source, child, inferenceId);
6376
this.queryText = queryText;
6477
this.rerankFields = rerankFields;
6578
this.scoreAttribute = scoreAttribute;
79+
this.rowLimit = rowLimit;
6680
}
6781

6882
public Rerank(StreamInput in) throws IOException {
@@ -72,7 +86,10 @@ public Rerank(StreamInput in) throws IOException {
7286
in.readNamedWriteable(Expression.class),
7387
in.readNamedWriteable(Expression.class),
7488
in.readCollectionAsList(Alias::new),
75-
in.readNamedWriteable(Attribute.class)
89+
in.readNamedWriteable(Attribute.class),
90+
in.getTransportVersion().supports(ESQL_RERANK_USAGE_LIMIT)
91+
? in.readNamedWriteable(Expression.class)
92+
: Literal.integer(Source.EMPTY, DEFAULT_MAX_ROW_LIMIT)
7693
);
7794
}
7895

@@ -82,6 +99,9 @@ public void writeTo(StreamOutput out) throws IOException {
8299
out.writeNamedWriteable(queryText);
83100
out.writeCollection(rerankFields());
84101
out.writeNamedWriteable(scoreAttribute);
102+
if (out.getTransportVersion().supports(ESQL_RERANK_USAGE_LIMIT)) {
103+
out.writeNamedWriteable(rowLimit);
104+
}
85105
}
86106

87107
public Expression queryText() {
@@ -96,6 +116,10 @@ public Attribute scoreAttribute() {
96116
return scoreAttribute;
97117
}
98118

119+
public Expression rowLimit() {
120+
return rowLimit;
121+
}
122+
99123
@Override
100124
public TaskType taskType() {
101125
return TaskType.RERANK;
@@ -106,23 +130,23 @@ public Rerank withInferenceId(Expression newInferenceId) {
106130
if (inferenceId().equals(newInferenceId)) {
107131
return this;
108132
}
109-
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
133+
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute, rowLimit);
110134
}
111135

112136
public Rerank withRerankFields(List<Alias> newRerankFields) {
113137
if (rerankFields.equals(newRerankFields)) {
114138
return this;
115139
}
116140

117-
return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute);
141+
return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute, rowLimit);
118142
}
119143

120144
public Rerank withScoreAttribute(Attribute newScoreAttribute) {
121145
if (scoreAttribute.equals(newScoreAttribute)) {
122146
return this;
123147
}
124148

125-
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute);
149+
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute, rowLimit);
126150
}
127151

128152
@Override
@@ -132,7 +156,7 @@ public String getWriteableName() {
132156

133157
@Override
134158
public UnaryPlan replaceChild(LogicalPlan newChild) {
135-
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
159+
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute, rowLimit);
136160
}
137161

138162
@Override
@@ -147,7 +171,7 @@ public List<Attribute> generatedAttributes() {
147171
@Override
148172
public Rerank withGeneratedNames(List<String> newNames) {
149173
checkNumberOfNewNames(newNames);
150-
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, this.renameScoreAttribute(newNames.get(0)));
174+
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, this.renameScoreAttribute(newNames.get(0)), rowLimit);
151175
}
152176

153177
private Attribute renameScoreAttribute(String newName) {
@@ -181,7 +205,7 @@ public boolean isFoldable() {
181205

182206
@Override
183207
protected NodeInfo<? extends LogicalPlan> info() {
184-
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
208+
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute, rowLimit);
185209
}
186210

187211
@Override
@@ -221,12 +245,13 @@ public boolean equals(Object o) {
221245
Rerank rerank = (Rerank) o;
222246
return Objects.equals(queryText, rerank.queryText)
223247
&& Objects.equals(rerankFields, rerank.rerankFields)
224-
&& Objects.equals(scoreAttribute, rerank.scoreAttribute);
248+
&& Objects.equals(scoreAttribute, rerank.scoreAttribute)
249+
&& Objects.equals(rowLimit, rerank.rowLimit);
225250
}
226251

227252
@Override
228253
public int hashCode() {
229-
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
254+
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute, rowLimit);
230255
}
231256

232257
@Override

0 commit comments

Comments
 (0)