Skip to content

Commit 7d986e6

Browse files
committed
Add support for chat_completion to the Completion physical plan.
1 parent 27ddca3 commit 7d986e6

File tree

6 files changed

+55
-15
lines changed

6 files changed

+55
-15
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ public void writeTo(StreamOutput out) throws IOException {
8686
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) {
8787
out.writeOptional((output, taskType) -> output.writeString(taskType.toString()), taskType());
8888
}
89-
9089
out.writeNamedWriteable(prompt);
9190
out.writeNamedWriteable(targetField);
9291
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
package org.elasticsearch.xpack.esql.plan.physical.inference;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.inference.TaskType;
1315
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1416
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
1517
import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -37,8 +39,15 @@ public class CompletionExec extends InferenceExec {
3739
private final Attribute targetField;
3840
private List<Attribute> lazyOutput;
3941

40-
public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
41-
super(source, child, inferenceId);
42+
public CompletionExec(
43+
Source source,
44+
PhysicalPlan child,
45+
Expression inferenceId,
46+
TaskType taskType,
47+
Expression prompt,
48+
Attribute targetField
49+
) {
50+
super(source, child, inferenceId, taskType);
4251
this.prompt = prompt;
4352
this.targetField = targetField;
4453
}
@@ -48,6 +57,9 @@ public CompletionExec(StreamInput in) throws IOException {
4857
Source.readFrom((PlanStreamInput) in),
4958
in.readNamedWriteable(PhysicalPlan.class),
5059
in.readNamedWriteable(Expression.class),
60+
in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)
61+
? TaskType.fromString(in.readString())
62+
: TaskType.COMPLETION,
5163
in.readNamedWriteable(Expression.class),
5264
in.readNamedWriteable(Attribute.class)
5365
);
@@ -61,6 +73,9 @@ public String getWriteableName() {
6173
@Override
6274
public void writeTo(StreamOutput out) throws IOException {
6375
super.writeTo(out);
76+
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) {
77+
out.writeString(taskType().toString());
78+
}
6479
out.writeNamedWriteable(prompt);
6580
out.writeNamedWriteable(targetField);
6681
}
@@ -75,12 +90,12 @@ public Attribute targetField() {
7590

7691
@Override
7792
protected NodeInfo<? extends PhysicalPlan> info() {
78-
return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField);
93+
return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), taskType(), prompt, targetField);
7994
}
8095

8196
@Override
8297
public UnaryExec replaceChild(PhysicalPlan newChild) {
83-
return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField);
98+
return new CompletionExec(source(), newChild, inferenceId(), taskType(), prompt, targetField);
8499
}
85100

86101
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.plan.physical.inference;
99

1010
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.inference.TaskType;
1112
import org.elasticsearch.xpack.esql.core.expression.Expression;
1213
import org.elasticsearch.xpack.esql.core.tree.Source;
1314
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
@@ -18,16 +19,22 @@
1819

1920
public abstract class InferenceExec extends UnaryExec {
2021
private final Expression inferenceId;
22+
private final TaskType taskType;
2123

22-
protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
24+
protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId, TaskType taskType) {
2325
super(source, child);
2426
this.inferenceId = inferenceId;
27+
this.taskType = taskType;
2528
}
2629

2730
public Expression inferenceId() {
2831
return inferenceId;
2932
}
3033

34+
public TaskType taskType() {
35+
return taskType;
36+
}
37+
3138
@Override
3239
public void writeTo(StreamOutput out) throws IOException {
3340
source().writeTo(out);
@@ -41,11 +48,11 @@ public boolean equals(Object o) {
4148
if (o == null || getClass() != o.getClass()) return false;
4249
if (super.equals(o) == false) return false;
4350
InferenceExec that = (InferenceExec) o;
44-
return inferenceId.equals(that.inferenceId);
51+
return inferenceId.equals(that.inferenceId) && taskType == that.taskType;
4552
}
4653

4754
@Override
4855
public int hashCode() {
49-
return Objects.hash(super.hashCode(), inferenceId());
56+
return Objects.hash(super.hashCode(), inferenceId(), taskType);
5057
}
5158
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.inference.TaskType;
1314
import org.elasticsearch.xpack.esql.core.expression.Alias;
1415
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1516
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
@@ -48,7 +49,7 @@ public RerankExec(
4849
List<Alias> rerankFields,
4950
Attribute scoreAttribute
5051
) {
51-
super(source, child, inferenceId);
52+
super(source, child, inferenceId, TaskType.RERANK);
5253
this.queryText = queryText;
5354
this.rerankFields = rerankFields;
5455
this.scoreAttribute = scoreAttribute;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,14 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
103103
}
104104

105105
if (p instanceof Completion completion) {
106-
return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField());
106+
return new CompletionExec(
107+
completion.source(),
108+
child,
109+
completion.inferenceId(),
110+
completion.taskType(),
111+
completion.prompt(),
112+
completion.targetField()
113+
);
107114
}
108115

109116
if (p instanceof Enrich enrich) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77

88
package org.elasticsearch.xpack.esql.plan.physical.inference;
99

10+
import org.elasticsearch.inference.TaskType;
1011
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1112
import org.elasticsearch.xpack.esql.core.expression.Expression;
1213
import org.elasticsearch.xpack.esql.core.expression.Literal;
1314
import org.elasticsearch.xpack.esql.core.tree.Source;
1415
import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
16+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
1517
import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
1618
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
1719

@@ -20,7 +22,14 @@
2022
public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests<CompletionExec> {
2123
@Override
2224
protected CompletionExec createTestInstance() {
23-
return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
25+
return new CompletionExec(
26+
randomSource(),
27+
randomChild(0),
28+
randomInferenceId(),
29+
randomFrom(Completion.SUPPORTED_TASK_TYPES),
30+
randomPrompt(),
31+
randomAttribute()
32+
);
2433
}
2534

2635
@Override
@@ -29,14 +38,16 @@ protected CompletionExec mutateInstance(CompletionExec instance) throws IOExcept
2938
Expression inferenceId = instance.inferenceId();
3039
Expression prompt = instance.prompt();
3140
Attribute targetField = instance.targetField();
41+
TaskType taskType = instance.taskType();
3242

33-
switch (between(0, 3)) {
43+
switch (between(0, 4)) {
3444
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
3545
case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
36-
case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
37-
case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
46+
case 2 -> taskType = randomValueOtherThan(taskType, () -> randomFrom(Completion.SUPPORTED_TASK_TYPES));
47+
case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
48+
case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
3849
}
39-
return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField);
50+
return new CompletionExec(instance.source(), child, inferenceId, taskType, prompt, targetField);
4051
}
4152

4253
private Literal randomInferenceId() {

0 commit comments

Comments
 (0)