diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java index b036962c679f1..d15b6aa2973aa 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java @@ -51,6 +51,7 @@ import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; import java.util.ArrayList; @@ -94,6 +95,7 @@ public static List logical() { public static List physical() { return List.of( AggregateExec.ENTRY, + CompletionExec.ENTRY, DissectExec.ENTRY, EnrichExec.ENTRY, EsQueryExec.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java index 85f60b4038a96..3d199fba495c6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java @@ -29,7 +29,7 @@ protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId @Override public void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); + source().writeTo(out); out.writeNamedWriteable(child()); out.writeNamedWriteable(inferenceId()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java new file mode 100644 index 0000000000000..80887ad08fe69 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.physical.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeSet; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.UnaryExec; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes; + +public class CompletionExec extends InferenceExec { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + PhysicalPlan.class, + "CompletionExec", + CompletionExec::new + ); + + private final Expression prompt; + private final Attribute targetField; + private List lazyOutput; + + public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) { + super(source, child, inferenceId); + this.prompt = prompt; + this.targetField = targetField; + } + + public CompletionExec(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(PhysicalPlan.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Attribute.class) + ); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeNamedWriteable(prompt); + out.writeNamedWriteable(targetField); + } + + public Expression prompt() { + return prompt; + } + + public Attribute targetField() { + return targetField; + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField); + } + + @Override + public UnaryExec replaceChild(PhysicalPlan newChild) { + return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField); + } + + @Override + public List output() { + if (lazyOutput == null) { + lazyOutput = mergeOutputAttributes(List.of(targetField), child().output()); + } + + return lazyOutput; + } + + @Override + protected AttributeSet computeReferences() { + return prompt.references(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + CompletionExec completion = (CompletionExec) o; + + return Objects.equals(prompt, completion.prompt) && Objects.equals(targetField, completion.targetField); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), prompt, targetField); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java index 7954690a0fdc0..d60a5ecccc384 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java @@ -30,7 +30,7 @@ public Expression inferenceId() { @Override public void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); + source().writeTo(out); out.writeNamedWriteable(child()); out.writeNamedWriteable(inferenceId()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java index 6f44634f40ebb..3db455a0c8bb6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval; import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo; @@ -43,6 +44,7 @@ import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; +import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; @@ -99,6 +101,10 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) { ); } + if (p instanceof Completion completion) { + return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField()); + } + if (p instanceof Enrich enrich) { return new EnrichExec( enrich.source(), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java index d5f7c868f4b47..0b6c1f4eb1b2c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -22,9 +22,7 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati @Override protected Completion createTestInstance() { - Source source = randomSource(); - LogicalPlan child = randomChild(0); - return new Completion(source, child, randomInferenceId(), randomPrompt(), randomAttribute()); + return new Completion(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute()); } @Override @@ -43,11 +41,6 @@ protected Completion mutateInstance(Completion instance) throws IOException { return new Completion(instance.source(), child, inferenceId, prompt, targetField); } - @Override - protected boolean alwaysEmptySource() { - return true; - } - private Literal randomInferenceId() { return new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD); } @@ -57,6 +50,6 @@ private Expression randomPrompt() { } private Attribute randomAttribute() { - return FieldAttributeTests.createFieldAttribute(3, randomBoolean()); + return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java index 1bb8bab502f92..22f60c78ec842 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java @@ -47,11 +47,6 @@ protected Rerank mutateInstance(Rerank instance) throws IOException { return new Rerank(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute()); } - @Override - protected boolean alwaysEmptySource() { - return true; - } - private List randomFields() { return randomList(0, 10, AliasTests::randomAlias); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java new file mode 100644 index 0000000000000..92d2a4b445c6a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.physical.inference; + +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; +import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; + +import java.io.IOException; + +public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + @Override + protected CompletionExec createTestInstance() { + return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute()); + } + + @Override + protected CompletionExec mutateInstance(CompletionExec instance) throws IOException { + PhysicalPlan child = instance.child(); + Expression inferenceId = instance.inferenceId(); + Expression prompt = instance.prompt(); + Attribute targetField = instance.targetField(); + + switch (between(0, 3)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId); + case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt); + case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute); + } + return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField); + } + + private Literal randomInferenceId() { + return new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD); + } + + private Expression randomPrompt() { + return randomBoolean() ? new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD) : randomAttribute(); + } + + private Attribute randomAttribute() { + return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java index ecdbb1a1b4fd0..f5ba1718c7ea0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java @@ -47,11 +47,6 @@ protected RerankExec mutateInstance(RerankExec instance) throws IOException { return new RerankExec(instance.source(), child, inferenceId, queryText, fields, scoreAttribute()); } - @Override - protected boolean alwaysEmptySource() { - return true; - } - private List randomFields() { return randomList(0, 10, AliasTests::randomAlias); }