Skip to content

Commit efbd414

Browse files
authored
ES|QL completion command constant folding (#138112)
1 parent 6e29516 commit efbd414

File tree

10 files changed

+446
-6
lines changed

10 files changed

+446
-6
lines changed

docs/changelog/138112.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 138112
2+
summary: ES|QL completion command constant folding
3+
area: ES|QL
4+
type: enhancement
5+
issues:
6+
- 136863

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
7878
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
7979
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
80+
import org.elasticsearch.xpack.esql.expression.function.inference.CompletionFunction;
8081
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
8182
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
8283
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
@@ -1465,8 +1466,8 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log
14651466

14661467
@Override
14671468
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
1468-
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
1469-
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
1469+
return plan.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context))
1470+
.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
14701471
}
14711472

14721473
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
@@ -1493,6 +1494,28 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext
14931494
return plan.withInferenceResolutionError(inferenceId, error);
14941495
}
14951496

1497+
if (plan.isFoldable()) {
1498+
// Transform foldable InferencePlan to Eval with function call
1499+
return transformToEval(plan, inferenceId);
1500+
}
1501+
1502+
return plan;
1503+
}
1504+
1505+
/**
1506+
* Transforms a foldable InferencePlan to an Eval with the appropriate function call.
1507+
*/
1508+
private LogicalPlan transformToEval(InferencePlan<?> plan, String inferenceId) {
1509+
Expression inferenceIdLiteral = Literal.keyword(plan.inferenceId().source(), inferenceId);
1510+
Source source = plan.source();
1511+
LogicalPlan child = plan.child();
1512+
1513+
if (plan instanceof Completion completion) {
1514+
CompletionFunction completionFunction = new CompletionFunction(source, completion.prompt(), inferenceIdLiteral);
1515+
Alias alias = new Alias(source, completion.targetField().name(), completionFunction, completion.targetField().id());
1516+
return new Eval(source, child, List.of(alias));
1517+
}
1518+
14961519
return plan;
14971520
}
14981521

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.xpack.esql.core.expression.Expression;
13+
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
14+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
15+
import org.elasticsearch.xpack.esql.core.tree.Source;
16+
import org.elasticsearch.xpack.esql.core.type.DataType;
17+
18+
import java.io.IOException;
19+
import java.util.List;
20+
import java.util.Objects;
21+
22+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
23+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
24+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
25+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
26+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
27+
28+
/**
29+
* COMPLETION function generates text completions from a prompt using an inference endpoint.
30+
* <p>
31+
* This function is an internal optimization primitive used exclusively for constant folding of
32+
* {@code COMPLETION} commands during the analysis phase. It should never be registered in the
33+
* function registry or exposed to users, as ESQL does not currently support async functions
34+
* in the function registry.
35+
* <p>
36+
* When a {@code COMPLETION} command has a foldable prompt (e.g., a literal or foldable expression),
37+
* the analyzer transforms it into an {@code EVAL} node with a {@code CompletionFunction} expression:
38+
* <pre>{@code
39+
* FROM books
40+
* | COMPLETION "Translate this text" WITH { "inference_id": "my-model" }
41+
* }</pre>
42+
* is internally rewritten into:
43+
* <pre>{@code
44+
* FROM books
45+
* | EVAL completion = COMPLETION("Translate this text", "my-model")
46+
* }</pre>
47+
* The pre-optimizer then evaluates this function using {@code InferenceFunctionEvaluator} and
48+
* replaces it with a literal result.
49+
*/
50+
public class CompletionFunction extends InferenceFunction<CompletionFunction> {
51+
52+
private final Expression inferenceId;
53+
private final Expression prompt;
54+
55+
public CompletionFunction(Source source, Expression prompt, Expression inferenceId) {
56+
super(source, List.of(prompt, inferenceId));
57+
this.inferenceId = inferenceId;
58+
this.prompt = prompt;
59+
}
60+
61+
@Override
62+
public void writeTo(StreamOutput out) throws IOException {
63+
throw new UnsupportedOperationException("doesn't escape the node");
64+
}
65+
66+
@Override
67+
public String getWriteableName() {
68+
throw new UnsupportedOperationException("doesn't escape the node");
69+
}
70+
71+
public Expression prompt() {
72+
return prompt;
73+
}
74+
75+
@Override
76+
public Expression inferenceId() {
77+
return inferenceId;
78+
}
79+
80+
@Override
81+
public boolean foldable() {
82+
return inferenceId.foldable() && prompt.foldable();
83+
}
84+
85+
@Override
86+
public DataType dataType() {
87+
return prompt.dataType() == DataType.NULL ? DataType.NULL : DataType.KEYWORD;
88+
}
89+
90+
@Override
91+
protected TypeResolution resolveType() {
92+
if (childrenResolved() == false) {
93+
return new TypeResolution("Unresolved children");
94+
}
95+
96+
TypeResolution promptResolution = isNotNull(prompt, sourceText(), FIRST).and(isFoldable(prompt, sourceText(), FIRST))
97+
.and(isType(prompt, DataType::isString, sourceText(), FIRST, "string"));
98+
99+
if (promptResolution.unresolved()) {
100+
return promptResolution;
101+
}
102+
103+
TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(
104+
isType(inferenceId, DataType.KEYWORD::equals, sourceText(), SECOND, "string")
105+
).and(isFoldable(inferenceId, sourceText(), SECOND));
106+
107+
if (inferenceIdResolution.unresolved()) {
108+
return inferenceIdResolution;
109+
}
110+
111+
return TypeResolution.TYPE_RESOLVED;
112+
}
113+
114+
@Override
115+
public TaskType taskType() {
116+
return TaskType.COMPLETION;
117+
}
118+
119+
@Override
120+
public CompletionFunction withInferenceResolutionError(String inferenceId, String error) {
121+
return new CompletionFunction(source(), prompt, new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
122+
}
123+
124+
@Override
125+
public Expression replaceChildren(List<Expression> newChildren) {
126+
return new CompletionFunction(source(), newChildren.get(0), newChildren.get(1));
127+
}
128+
129+
@Override
130+
protected NodeInfo<? extends Expression> info() {
131+
return NodeInfo.create(this, CompletionFunction::new, prompt, inferenceId);
132+
}
133+
134+
@Override
135+
public String toString() {
136+
return "COMPLETION(" + prompt + ", " + inferenceId + ")";
137+
}
138+
139+
@Override
140+
public boolean equals(Object o) {
141+
if (o == null || getClass() != o.getClass()) return false;
142+
if (super.equals(o) == false) return false;
143+
CompletionFunction completionFunction = (CompletionFunction) o;
144+
return Objects.equals(inferenceId, completionFunction.inferenceId) && Objects.equals(prompt, completionFunction.prompt);
145+
}
146+
147+
@Override
148+
public int hashCode() {
149+
return Objects.hash(super.hashCode(), inferenceId, prompt);
150+
}
151+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
import org.elasticsearch.xpack.esql.core.expression.Literal;
2828
import org.elasticsearch.xpack.esql.core.type.DataType;
2929
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
30+
import org.elasticsearch.xpack.esql.expression.function.inference.CompletionFunction;
3031
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
3132
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
33+
import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator;
3234
import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator;
3335

3436
import java.util.List;
@@ -45,7 +47,6 @@ public static InferenceFunctionEvaluator.Factory factory() {
4547
return FACTORY;
4648
}
4749

48-
private final FoldContext foldContext;
4950
private final InferenceOperatorProvider inferenceOperatorProvider;
5051

5152
/**
@@ -56,7 +57,6 @@ public static InferenceFunctionEvaluator.Factory factory() {
5657
* @param inferenceOperatorProvider custom provider for creating inference operators
5758
*/
5859
InferenceFunctionEvaluator(FoldContext foldContext, InferenceOperatorProvider inferenceOperatorProvider) {
59-
this.foldContext = foldContext;
6060
this.inferenceOperatorProvider = inferenceOperatorProvider;
6161
}
6262

@@ -213,6 +213,11 @@ private InferenceOperatorProvider createInferenceOperatorProvider(FoldContext fo
213213
inferenceId(inferenceFunction, foldContext),
214214
expressionEvaluatorFactory(textEmbedding.inputText(), foldContext)
215215
);
216+
case CompletionFunction completion -> new CompletionOperator.Factory(
217+
inferenceService,
218+
inferenceId(inferenceFunction, foldContext),
219+
expressionEvaluatorFactory(completion.prompt(), foldContext)
220+
);
216221
default -> throw new IllegalArgumentException("Unknown inference function: " + inferenceFunction.getClass().getName());
217222
};
218223

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ public boolean expressionsResolved() {
143143
return super.expressionsResolved() && prompt.resolved() && targetField.resolved();
144144
}
145145

146+
@Override
147+
public boolean isFoldable() {
148+
return prompt.foldable();
149+
}
150+
146151
@Override
147152
public void postAnalysisVerification(Failures failures) {
148153
if (prompt.resolved() && DataType.isString(prompt.dataType()) == false) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,10 @@ public PlanType withInferenceResolutionError(String inferenceId, String error) {
8181
public List<String> validOptionNames() {
8282
return VALID_INFERENCE_OPTION_NAMES;
8383
}
84+
85+
/**
86+
* Checks if this InferencePlan is foldable (all input expressions are foldable).
87+
* A plan is foldable if all its input expressions can be evaluated statically.
88+
*/
89+
public abstract boolean isFoldable();
8490
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ public boolean expressionsResolved() {
174174
return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields) && scoreAttribute.resolved();
175175
}
176176

177+
@Override
178+
public boolean isFoldable() {
179+
return false;
180+
}
181+
177182
@Override
178183
protected NodeInfo<? extends LogicalPlan> info() {
179184
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
6060
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
6161
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
62+
import org.elasticsearch.xpack.esql.expression.function.inference.CompletionFunction;
6263
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
6364
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
6465
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
@@ -4291,6 +4292,67 @@ public void testResolveCompletionOutputFieldOverwriteInputField() {
42914292
assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField())));
42924293
}
42934294

4295+
public void testFoldableCompletionTransformedToEval() {
4296+
// Test that a foldable Completion plan (with literal prompt) is transformed to Eval with CompletionFunction
4297+
LogicalPlan plan = analyze("""
4298+
FROM books METADATA _score
4299+
| COMPLETION "Translate this text in French" WITH { "inference_id" : "completion-inference-id" }
4300+
""", "mapping-books.json");
4301+
4302+
Eval eval = as(as(plan, Limit.class).child(), Eval.class);
4303+
assertThat(eval.fields().size(), equalTo(1));
4304+
4305+
Alias alias = eval.fields().get(0);
4306+
assertThat(alias.name(), equalTo("completion"));
4307+
assertThat(alias.child(), instanceOf(CompletionFunction.class));
4308+
4309+
CompletionFunction completionFunction = as(alias.child(), CompletionFunction.class);
4310+
assertThat(completionFunction.prompt(), equalTo(string("Translate this text in French")));
4311+
assertThat(completionFunction.inferenceId(), equalTo(string("completion-inference-id")));
4312+
assertThat(completionFunction.taskType(), equalTo(org.elasticsearch.inference.TaskType.COMPLETION));
4313+
}
4314+
4315+
public void testFoldableCompletionWithCustomTargetFieldTransformedToEval() {
4316+
// Test that a foldable Completion plan with custom target field is transformed correctly
4317+
LogicalPlan plan = analyze("""
4318+
FROM books METADATA _score
4319+
| COMPLETION translation = "Translate this text" WITH { "inference_id" : "completion-inference-id" }
4320+
""", "mapping-books.json");
4321+
4322+
Eval eval = as(as(plan, Limit.class).child(), Eval.class);
4323+
assertThat(eval.fields().size(), equalTo(1));
4324+
4325+
Alias alias = eval.fields().get(0);
4326+
assertThat(alias.name(), equalTo("translation"));
4327+
assertThat(alias.child(), instanceOf(CompletionFunction.class));
4328+
4329+
CompletionFunction completionFunction = as(alias.child(), CompletionFunction.class);
4330+
assertThat(completionFunction.prompt(), equalTo(string("Translate this text")));
4331+
assertThat(completionFunction.inferenceId(), equalTo(string("completion-inference-id")));
4332+
}
4333+
4334+
public void testFoldableCompletionWithFoldableExpressionTransformedToEval() {
4335+
// Test that a foldable Completion plan with a foldable expression (not just a literal) is transformed correctly
4336+
// Using CONCAT with all literal arguments to ensure it's foldable during analysis
4337+
LogicalPlan plan = analyze("""
4338+
FROM books METADATA _score
4339+
| COMPLETION CONCAT("Translate", " ", "this text") WITH { "inference_id" : "completion-inference-id" }
4340+
""", "mapping-books.json");
4341+
4342+
Eval eval = as(as(plan, Limit.class).child(), Eval.class);
4343+
assertThat(eval.fields().size(), equalTo(1));
4344+
4345+
Alias alias = eval.fields().get(0);
4346+
assertThat(alias.name(), equalTo("completion"));
4347+
assertThat(alias.child(), instanceOf(CompletionFunction.class));
4348+
4349+
CompletionFunction completionFunction = as(alias.child(), CompletionFunction.class);
4350+
// The prompt should be a Concat expression that is foldable (all arguments are literals)
4351+
assertThat(completionFunction.prompt(), instanceOf(Concat.class));
4352+
assertThat(completionFunction.prompt().foldable(), equalTo(true));
4353+
assertThat(completionFunction.inferenceId(), equalTo(string("completion-inference-id")));
4354+
}
4355+
42944356
public void testResolveGroupingsBeforeResolvingImplicitReferencesToGroupings() {
42954357
var plan = analyze("""
42964358
FROM test

0 commit comments

Comments
 (0)