Skip to content

Commit 350ea7b

Browse files
committed
[ES|QL] COMPLETION command physical plan (elastic#126766)
1 parent d98b76a commit 350ea7b

File tree

9 files changed

+181
-22
lines changed

9 files changed

+181
-22
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
5050
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
5151
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
52+
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
5253
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
5354

5455
import java.util.ArrayList;
@@ -91,6 +92,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
9192
public static List<NamedWriteableRegistry.Entry> physical() {
9293
return List.of(
9394
AggregateExec.ENTRY,
95+
CompletionExec.ENTRY,
9496
DissectExec.ENTRY,
9597
EnrichExec.ENTRY,
9698
EsQueryExec.ENTRY,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId
2929

3030
@Override
3131
public void writeTo(StreamOutput out) throws IOException {
32-
Source.EMPTY.writeTo(out);
32+
source().writeTo(out);
3333
out.writeNamedWriteable(child());
3434
out.writeNamedWriteable(inferenceId());
3535
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
14+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
15+
import org.elasticsearch.xpack.esql.core.expression.Expression;
16+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
17+
import org.elasticsearch.xpack.esql.core.tree.Source;
18+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
19+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
20+
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
21+
22+
import java.io.IOException;
23+
import java.util.List;
24+
import java.util.Objects;
25+
26+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
27+
28+
public class CompletionExec extends InferenceExec {
29+
30+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
31+
PhysicalPlan.class,
32+
"CompletionExec",
33+
CompletionExec::new
34+
);
35+
36+
private final Expression prompt;
37+
private final Attribute targetField;
38+
private List<Attribute> lazyOutput;
39+
40+
public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
41+
super(source, child, inferenceId);
42+
this.prompt = prompt;
43+
this.targetField = targetField;
44+
}
45+
46+
public CompletionExec(StreamInput in) throws IOException {
47+
this(
48+
Source.readFrom((PlanStreamInput) in),
49+
in.readNamedWriteable(PhysicalPlan.class),
50+
in.readNamedWriteable(Expression.class),
51+
in.readNamedWriteable(Expression.class),
52+
in.readNamedWriteable(Attribute.class)
53+
);
54+
}
55+
56+
@Override
57+
public String getWriteableName() {
58+
return ENTRY.name;
59+
}
60+
61+
@Override
62+
public void writeTo(StreamOutput out) throws IOException {
63+
super.writeTo(out);
64+
out.writeNamedWriteable(prompt);
65+
out.writeNamedWriteable(targetField);
66+
}
67+
68+
public Expression prompt() {
69+
return prompt;
70+
}
71+
72+
public Attribute targetField() {
73+
return targetField;
74+
}
75+
76+
@Override
77+
protected NodeInfo<? extends PhysicalPlan> info() {
78+
return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField);
79+
}
80+
81+
@Override
82+
public UnaryExec replaceChild(PhysicalPlan newChild) {
83+
return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField);
84+
}
85+
86+
@Override
87+
public List<Attribute> output() {
88+
if (lazyOutput == null) {
89+
lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
90+
}
91+
92+
return lazyOutput;
93+
}
94+
95+
@Override
96+
protected AttributeSet computeReferences() {
97+
return prompt.references();
98+
}
99+
100+
@Override
101+
public boolean equals(Object o) {
102+
if (this == o) return true;
103+
if (o == null || getClass() != o.getClass()) return false;
104+
if (super.equals(o) == false) return false;
105+
CompletionExec completion = (CompletionExec) o;
106+
107+
return Objects.equals(prompt, completion.prompt) && Objects.equals(targetField, completion.targetField);
108+
}
109+
110+
@Override
111+
public int hashCode() {
112+
return Objects.hash(super.hashCode(), prompt, targetField);
113+
}
114+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public Expression inferenceId() {
3030

3131
@Override
3232
public void writeTo(StreamOutput out) throws IOException {
33-
Source.EMPTY.writeTo(out);
33+
source().writeTo(out);
3434
out.writeNamedWriteable(child());
3535
out.writeNamedWriteable(inferenceId());
3636
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
2525
import org.elasticsearch.xpack.esql.plan.logical.Project;
2626
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
27+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
2728
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
2829
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
2930
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
@@ -39,6 +40,7 @@
3940
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
4041
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
4142
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
43+
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
4244
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
4345
import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
4446

@@ -95,6 +97,10 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
9597
);
9698
}
9799

100+
if (p instanceof Completion completion) {
101+
return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField());
102+
}
103+
98104
if (p instanceof Enrich enrich) {
99105
return new EnrichExec(
100106
enrich.source(),

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import org.elasticsearch.xpack.esql.core.expression.Literal;
1313
import org.elasticsearch.xpack.esql.core.tree.Source;
1414
import org.elasticsearch.xpack.esql.core.type.DataType;
15-
import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests;
15+
import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
1616
import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
1717
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1818

@@ -22,9 +22,7 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati
2222

2323
@Override
2424
protected Completion createTestInstance() {
25-
Source source = randomSource();
26-
LogicalPlan child = randomChild(0);
27-
return new Completion(source, child, randomInferenceId(), randomPrompt(), randomAttribute());
25+
return new Completion(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
2826
}
2927

3028
@Override
@@ -43,11 +41,6 @@ protected Completion mutateInstance(Completion instance) throws IOException {
4341
return new Completion(instance.source(), child, inferenceId, prompt, targetField);
4442
}
4543

46-
@Override
47-
protected boolean alwaysEmptySource() {
48-
return true;
49-
}
50-
5144
private Literal randomInferenceId() {
5245
return new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD);
5346
}
@@ -57,6 +50,6 @@ private Expression randomPrompt() {
5750
}
5851

5952
private Attribute randomAttribute() {
60-
return FieldAttributeTests.createFieldAttribute(3, randomBoolean());
53+
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
6154
}
6255
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ protected Rerank mutateInstance(Rerank instance) throws IOException {
4747
return new Rerank(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute());
4848
}
4949

50-
@Override
51-
protected boolean alwaysEmptySource() {
52-
return true;
53-
}
54-
5550
private List<Alias> randomFields() {
5651
return randomList(0, 10, AliasTests::randomAlias);
5752
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.Literal;
13+
import org.elasticsearch.xpack.esql.core.tree.Source;
14+
import org.elasticsearch.xpack.esql.core.type.DataType;
15+
import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
16+
import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
17+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
18+
19+
import java.io.IOException;
20+
21+
public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests<CompletionExec> {
22+
@Override
23+
protected CompletionExec createTestInstance() {
24+
return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
25+
}
26+
27+
@Override
28+
protected CompletionExec mutateInstance(CompletionExec instance) throws IOException {
29+
PhysicalPlan child = instance.child();
30+
Expression inferenceId = instance.inferenceId();
31+
Expression prompt = instance.prompt();
32+
Attribute targetField = instance.targetField();
33+
34+
switch (between(0, 3)) {
35+
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
36+
case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
37+
case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
38+
case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
39+
}
40+
return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField);
41+
}
42+
43+
private Literal randomInferenceId() {
44+
return new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD);
45+
}
46+
47+
private Expression randomPrompt() {
48+
return randomBoolean() ? new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD) : randomAttribute();
49+
}
50+
51+
private Attribute randomAttribute() {
52+
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
53+
}
54+
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ protected RerankExec mutateInstance(RerankExec instance) throws IOException {
4747
return new RerankExec(instance.source(), child, inferenceId, queryText, fields, scoreAttribute());
4848
}
4949

50-
@Override
51-
protected boolean alwaysEmptySource() {
52-
return true;
53-
}
54-
5550
private List<Alias> randomFields() {
5651
return randomList(0, 10, AliasTests::randomAlias);
5752
}

0 commit comments

Comments
 (0)