Skip to content

Commit 56c1a1e

Browse files
author
afoucret
committed
Add limit to Completion and Rerank command.
1 parent c95dcea commit 56c1a1e

File tree

12 files changed

+173
-35
lines changed

12 files changed

+173
-35
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9231000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
ml_groq_inference_service,9230000
1+
esql_inference_usage_limit,9231000

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutp
605605
prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
606606
}
607607

608-
return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
608+
return new Completion(p.source(), p.child(), p.inferenceId(), p.rowLimit(), prompt, targetField);
609609
}
610610

611611
private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
import org.elasticsearch.xpack.esql.core.expression.Literal;
12+
import org.elasticsearch.xpack.esql.core.tree.Source;
13+
14+
/**
15+
* Interface for logical plans that enforce a row limit.
16+
* Plans implementing this interface have a maximum number of rows they can handle,
17+
* which may be enforced during plan transformation or execution.
18+
*
19+
* <p>
20+
* Practically it means that a LIMIT to the plan children.
21+
*/
22+
public interface RowLimited extends SurrogateLogicalPlan {
23+
/**
24+
* Returns the maximum number of rows this plan can produce.
25+
*/
26+
int maxRows();
27+
28+
/**
29+
* Sets the maximum number of rows this plan can produce
30+
*/
31+
default RowLimited withMaxRows(int maxRows) {
32+
return withMaxRows(Literal.integer(Source.EMPTY, maxRows));
33+
}
34+
35+
/**
36+
* Sets the maximum number of rows this plan can produce
37+
*/
38+
RowLimited withMaxRows(Expression maxRows);
39+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.xpack.esql.core.type.DataType;
2525
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2626
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
27+
import org.elasticsearch.xpack.esql.plan.logical.RowLimited;
2728

2829
import java.io.IOException;
2930
import java.util.List;
@@ -36,6 +37,7 @@
3637
public class Completion extends InferencePlan<Completion> implements TelemetryAware, PostAnalysisVerificationAware {
3738

3839
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
40+
public static final int DEFAULT_MAX_ROW_LIMIT = 100;
3941

4042
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
4143
LogicalPlan.class,
@@ -47,11 +49,18 @@ public class Completion extends InferencePlan<Completion> implements TelemetryAw
4749
private List<Attribute> lazyOutput;
4850

4951
public Completion(Source source, LogicalPlan p, Expression prompt, Attribute targetField) {
50-
this(source, p, Literal.keyword(Source.EMPTY, DEFAULT_OUTPUT_FIELD_NAME), prompt, targetField);
51-
}
52-
53-
public Completion(Source source, LogicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
54-
super(source, child, inferenceId);
52+
this(source, p, Literal.NULL, Literal.integer(Source.EMPTY, DEFAULT_MAX_ROW_LIMIT), prompt, targetField);
53+
}
54+
55+
public Completion(
56+
Source source,
57+
LogicalPlan child,
58+
Expression inferenceId,
59+
Expression rowLimit,
60+
Expression prompt,
61+
Attribute targetField
62+
) {
63+
super(source, child, inferenceId, rowLimit);
5564
this.prompt = prompt;
5665
this.targetField = targetField;
5766
}
@@ -61,6 +70,9 @@ public Completion(StreamInput in) throws IOException {
6170
Source.readFrom((PlanStreamInput) in),
6271
in.readNamedWriteable(LogicalPlan.class),
6372
in.readNamedWriteable(Expression.class),
73+
in.getTransportVersion().supports(ESQL_INFERENCE_USAGE_LIMIT)
74+
? in.readNamedWriteable(Expression.class)
75+
: Literal.integer(Source.EMPTY, DEFAULT_MAX_ROW_LIMIT),
6476
in.readNamedWriteable(Expression.class),
6577
in.readNamedWriteable(Attribute.class)
6678
);
@@ -81,18 +93,23 @@ public Attribute targetField() {
8193
return targetField;
8294
}
8395

96+
@Override
97+
public RowLimited withMaxRows(Expression rowLimit) {
98+
return new Completion(source(), child(), inferenceId(), rowLimit, prompt, targetField);
99+
}
100+
84101
@Override
85102
public Completion withInferenceId(Expression newInferenceId) {
86103
if (inferenceId().equals(newInferenceId)) {
87104
return this;
88105
}
89106

90-
return new Completion(source(), child(), newInferenceId, prompt, targetField);
107+
return new Completion(source(), child(), newInferenceId, rowLimit(), prompt, targetField);
91108
}
92109

93110
@Override
94111
public Completion replaceChild(LogicalPlan newChild) {
95-
return new Completion(source(), newChild, inferenceId(), prompt, targetField);
112+
return new Completion(source(), newChild, inferenceId(), rowLimit(), prompt, targetField);
96113
}
97114

98115
@Override
@@ -122,7 +139,7 @@ public List<Attribute> generatedAttributes() {
122139
@Override
123140
public Completion withGeneratedNames(List<String> newNames) {
124141
checkNumberOfNewNames(newNames);
125-
return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0)));
142+
return new Completion(source(), child(), inferenceId(), rowLimit(), prompt, this.renameTargetField(newNames.get(0)));
126143
}
127144

128145
private Attribute renameTargetField(String newName) {
@@ -157,7 +174,7 @@ public void postAnalysisVerification(Failures failures) {
157174

158175
@Override
159176
protected NodeInfo<? extends LogicalPlan> info() {
160-
return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);
177+
return NodeInfo.create(this, Completion::new, child(), inferenceId(), rowLimit(), prompt, targetField);
161178
}
162179

163180
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77

88
package org.elasticsearch.xpack.esql.plan.logical.inference;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.common.io.stream.StreamOutput;
1112
import org.elasticsearch.inference.TaskType;
1213
import org.elasticsearch.xpack.esql.core.expression.Expression;
1314
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
1415
import org.elasticsearch.xpack.esql.core.tree.Source;
16+
import org.elasticsearch.xpack.esql.expression.Foldables;
1517
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
1618
import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
1719
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
20+
import org.elasticsearch.xpack.esql.plan.logical.RowLimited;
1821
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
1922
import org.elasticsearch.xpack.esql.plan.logical.Streaming;
2023
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
@@ -28,32 +31,55 @@ public abstract class InferencePlan<PlanType extends InferencePlan<PlanType>> ex
2831
Streaming,
2932
SortAgnostic,
3033
GeneratingPlan<InferencePlan<PlanType>>,
31-
ExecutesOn.Coordinator {
34+
ExecutesOn.Coordinator,
35+
RowLimited {
36+
37+
protected static final TransportVersion ESQL_INFERENCE_USAGE_LIMIT = TransportVersion.fromName("esql_inference_usage_limit");
3238

3339
public static final String INFERENCE_ID_OPTION_NAME = "inference_id";
3440
public static final List<String> VALID_INFERENCE_OPTION_NAMES = List.of(INFERENCE_ID_OPTION_NAME);
3541

3642
private final Expression inferenceId;
43+
private final Expression rowLimit;
3744

38-
protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) {
45+
protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId, Expression rowLimit) {
3946
super(source, child);
4047
this.inferenceId = inferenceId;
48+
this.rowLimit = rowLimit;
4149
}
4250

4351
@Override
4452
public void writeTo(StreamOutput out) throws IOException {
4553
source().writeTo(out);
4654
out.writeNamedWriteable(child());
4755
out.writeNamedWriteable(inferenceId());
56+
57+
if (out.getTransportVersion().supports(ESQL_INFERENCE_USAGE_LIMIT)) {
58+
out.writeNamedWriteable(rowLimit());
59+
}
4860
}
4961

5062
public Expression inferenceId() {
5163
return inferenceId;
5264
}
5365

66+
public Expression rowLimit() {
67+
return rowLimit;
68+
}
69+
70+
@Override
71+
public int maxRows() {
72+
return Foldables.intValueOf(rowLimit, rowLimit.sourceText(), "row limit");
73+
}
74+
75+
@Override
76+
public LogicalPlan surrogate() {
77+
return this;
78+
}
79+
5480
@Override
5581
public boolean expressionsResolved() {
56-
return inferenceId.resolved();
82+
return inferenceId.resolved() && rowLimit.resolved();
5783
}
5884

5985
@Override
@@ -62,12 +88,12 @@ public boolean equals(Object o) {
6288
if (o == null || getClass() != o.getClass()) return false;
6389
if (super.equals(o) == false) return false;
6490
InferencePlan<?> other = (InferencePlan<?>) o;
65-
return Objects.equals(inferenceId(), other.inferenceId());
91+
return Objects.equals(inferenceId, other.inferenceId) && Objects.equals(rowLimit, other.rowLimit);
6692
}
6793

6894
@Override
6995
public int hashCode() {
70-
return Objects.hash(super.hashCode(), inferenceId());
96+
return Objects.hash(super.hashCode(), inferenceId(), rowLimit());
7197
}
7298

7399
public abstract TaskType taskType();

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2828
import org.elasticsearch.xpack.esql.plan.logical.Eval;
2929
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
30+
import org.elasticsearch.xpack.esql.plan.logical.RowLimited;
3031
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
3132

3233
import java.io.IOException;
@@ -41,25 +42,35 @@ public class Rerank extends InferencePlan<Rerank> implements PostAnalysisVerific
4142

4243
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
4344
public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";
45+
public static final int DEFAULT_MAX_ROW_LIMIT = 1000;
4446

4547
private final Attribute scoreAttribute;
4648
private final Expression queryText;
4749
private final List<Alias> rerankFields;
4850
private List<Attribute> lazyOutput;
4951

5052
public Rerank(Source source, LogicalPlan child, Expression queryText, List<Alias> rerankFields, Attribute scoreAttribute) {
51-
this(source, child, Literal.keyword(Source.EMPTY, DEFAULT_INFERENCE_ID), queryText, rerankFields, scoreAttribute);
53+
this(
54+
source,
55+
child,
56+
Literal.keyword(Source.EMPTY, DEFAULT_INFERENCE_ID),
57+
Literal.integer(Source.EMPTY, DEFAULT_MAX_ROW_LIMIT),
58+
queryText,
59+
rerankFields,
60+
scoreAttribute
61+
);
5262
}
5363

5464
public Rerank(
5565
Source source,
5666
LogicalPlan child,
5767
Expression inferenceId,
68+
Expression rowLimit,
5869
Expression queryText,
5970
List<Alias> rerankFields,
6071
Attribute scoreAttribute
6172
) {
62-
super(source, child, inferenceId);
73+
super(source, child, inferenceId, rowLimit);
6374
this.queryText = queryText;
6475
this.rerankFields = rerankFields;
6576
this.scoreAttribute = scoreAttribute;
@@ -70,6 +81,9 @@ public Rerank(StreamInput in) throws IOException {
7081
Source.readFrom((PlanStreamInput) in),
7182
in.readNamedWriteable(LogicalPlan.class),
7283
in.readNamedWriteable(Expression.class),
84+
in.getTransportVersion().supports(ESQL_INFERENCE_USAGE_LIMIT)
85+
? in.readNamedWriteable(Expression.class)
86+
: Literal.integer(Source.EMPTY, DEFAULT_MAX_ROW_LIMIT),
7387
in.readNamedWriteable(Expression.class),
7488
in.readCollectionAsList(Alias::new),
7589
in.readNamedWriteable(Attribute.class)
@@ -101,28 +115,33 @@ public TaskType taskType() {
101115
return TaskType.RERANK;
102116
}
103117

118+
@Override
119+
public RowLimited withMaxRows(Expression rowLimit) {
120+
return new Rerank(source(), child(), inferenceId(), rowLimit(), queryText, rerankFields, scoreAttribute);
121+
}
122+
104123
@Override
105124
public Rerank withInferenceId(Expression newInferenceId) {
106125
if (inferenceId().equals(newInferenceId)) {
107126
return this;
108127
}
109-
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
128+
return new Rerank(source(), child(), newInferenceId, rowLimit(), queryText, rerankFields, scoreAttribute);
110129
}
111130

112131
public Rerank withRerankFields(List<Alias> newRerankFields) {
113132
if (rerankFields.equals(newRerankFields)) {
114133
return this;
115134
}
116135

117-
return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute);
136+
return new Rerank(source(), child(), inferenceId(), rowLimit(), queryText, newRerankFields, scoreAttribute);
118137
}
119138

120139
public Rerank withScoreAttribute(Attribute newScoreAttribute) {
121140
if (scoreAttribute.equals(newScoreAttribute)) {
122141
return this;
123142
}
124143

125-
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute);
144+
return new Rerank(source(), child(), inferenceId(), rowLimit(), queryText, rerankFields, newScoreAttribute);
126145
}
127146

128147
@Override
@@ -132,7 +151,7 @@ public String getWriteableName() {
132151

133152
@Override
134153
public UnaryPlan replaceChild(LogicalPlan newChild) {
135-
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
154+
return new Rerank(source(), newChild, inferenceId(), rowLimit(), queryText, rerankFields, scoreAttribute);
136155
}
137156

138157
@Override
@@ -147,7 +166,15 @@ public List<Attribute> generatedAttributes() {
147166
@Override
148167
public Rerank withGeneratedNames(List<String> newNames) {
149168
checkNumberOfNewNames(newNames);
150-
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, this.renameScoreAttribute(newNames.get(0)));
169+
return new Rerank(
170+
source(),
171+
child(),
172+
inferenceId(),
173+
rowLimit(),
174+
queryText,
175+
rerankFields,
176+
this.renameScoreAttribute(newNames.get(0))
177+
);
151178
}
152179

153180
private Attribute renameScoreAttribute(String newName) {
@@ -181,7 +208,7 @@ public boolean isFoldable() {
181208

182209
@Override
183210
protected NodeInfo<? extends LogicalPlan> info() {
184-
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
211+
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), rowLimit(), queryText, rerankFields, scoreAttribute);
185212
}
186213

187214
@Override

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5383,6 +5383,7 @@ record PushdownShadowingGeneratingPlanTestCase(
53835383
EMPTY,
53845384
plan,
53855385
randomLiteral(TEXT),
5386+
randomLiteral(INTEGER),
53865387
new Concat(EMPTY, randomLiteral(TEXT), List.of(attr)),
53875388
new ReferenceAttribute(EMPTY, "y", KEYWORD)
53885389
),
@@ -5394,6 +5395,7 @@ record PushdownShadowingGeneratingPlanTestCase(
53945395
EMPTY,
53955396
plan,
53965397
randomLiteral(TEXT),
5398+
randomLiteral(INTEGER),
53975399
randomLiteral(TEXT),
53985400
List.of(new Alias(EMPTY, attr.name(), attr)),
53995401
new ReferenceAttribute(EMPTY, "y", KEYWORD)

0 commit comments

Comments
 (0)