From 27ddca38daff8209bd5544ff986496fa427efa6e Mon Sep 17 00:00:00 2001
From: afoucret 
Date: Fri, 12 Sep 2025 09:35:55 +0200
Subject: [PATCH 1/4] Add support for chat_completion to the Completion logical
 plan.
---
 .../org/elasticsearch/TransportVersions.java  |  1 +
 .../xpack/esql/analysis/Analyzer.java         | 17 +------
 .../plan/logical/inference/Completion.java    | 47 ++++++++++++++-----
 .../plan/logical/inference/InferencePlan.java | 33 +++++++++++--
 .../esql/plan/logical/inference/Rerank.java   | 14 +++---
 .../xpack/esql/analysis/AnalyzerTests.java    |  2 +-
 .../optimizer/LogicalPlanOptimizerTests.java  |  1 +
 .../PushDownAndCombineFiltersTests.java       |  2 +
 .../PushDownAndCombineLimitsTests.java        |  9 +++-
 .../CompletionSerializationTests.java         | 28 +++++++++--
 10 files changed, 109 insertions(+), 45 deletions(-)
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 10114e429cc7a..09759cadaea70 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -324,6 +324,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00);
     public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00);
     public static final TransportVersion INDEX_SOURCE = def(9_158_0_00);
+    public static final TransportVersion ESQL_CHAT_COMPLETION_SUPPORT = def(9_159_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index 1f51055094c92..c8d2de3d1ad7c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -623,7 +623,7 @@ private LogicalPlan resolveCompletion(Completion p, List childrenOutp
                 prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
             }
 
-            return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
+            return new Completion(p.source(), p.child(), p.inferenceId(), p.taskType(), prompt, targetField);
         }
 
         private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
@@ -1349,20 +1349,7 @@ private LogicalPlan resolveInferencePlan(InferencePlan> plan, AnalyzerContext
                 return plan.withInferenceResolutionError(inferenceId, error);
             }
 
-            if (resolvedInference.taskType() != plan.taskType()) {
-                String error = "cannot use inference endpoint ["
-                    + inferenceId
-                    + "] with task type ["
-                    + resolvedInference.taskType()
-                    + "] within a "
-                    + plan.nodeName()
-                    + " command. Only inference endpoints with the task type ["
-                    + plan.taskType()
-                    + "] are supported.";
-                return plan.withInferenceResolutionError(inferenceId, error);
-            }
-
-            return plan;
+            return plan.withResolvedInference(resolvedInference);
         }
     }
 
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
index 191664bea9a81..cad7241c584b8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.esql.plan.logical.inference;
 
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -22,6 +23,7 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.inference.ResolvedInference;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 
@@ -37,6 +39,8 @@ public class Completion extends InferencePlan implements TelemetryAw
 
     public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
 
+    public static final List<TaskType> SUPPORTED_TASK_TYPES = List.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
         LogicalPlan.class,
         "Completion",
@@ -47,11 +51,18 @@ public class Completion extends InferencePlan implements TelemetryAw
     private List<Attribute> lazyOutput;
 
     public Completion(Source source, LogicalPlan p, Expression prompt, Attribute targetField) {
-        this(source, p, Literal.keyword(Source.EMPTY, DEFAULT_OUTPUT_FIELD_NAME), prompt, targetField);
-    }
-
-    public Completion(Source source, LogicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
-        super(source, child, inferenceId);
+        this(source, p, Literal.NULL, null, prompt, targetField);
+    }
+
+    public Completion(
+        Source source,
+        LogicalPlan child,
+        Expression inferenceId,
+        TaskType taskType,
+        Expression prompt,
+        Attribute targetField
+    ) {
+        super(source, child, inferenceId, taskType);
         this.prompt = prompt;
         this.targetField = targetField;
     }
@@ -61,6 +72,9 @@ public Completion(StreamInput in) throws IOException {
             Source.readFrom((PlanStreamInput) in),
             in.readNamedWriteable(LogicalPlan.class),
             in.readNamedWriteable(Expression.class),
+            in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)
+                ? in.readOptional(input -> TaskType.fromString(input.readString()))
+                : TaskType.COMPLETION,
             in.readNamedWriteable(Expression.class),
             in.readNamedWriteable(Attribute.class)
         );
@@ -69,6 +83,10 @@ public Completion(StreamInput in) throws IOException {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) {
+            out.writeOptional((output, taskType) -> output.writeString(taskType.toString()), taskType());
+        }
+
         out.writeNamedWriteable(prompt);
         out.writeNamedWriteable(targetField);
     }
@@ -87,17 +105,17 @@ public Completion withInferenceId(Expression newInferenceId) {
             return this;
         }
 
-        return new Completion(source(), child(), newInferenceId, prompt, targetField);
+        return new Completion(source(), child(), newInferenceId, taskType(), prompt, targetField);
     }
 
     @Override
-    public Completion replaceChild(LogicalPlan newChild) {
-        return new Completion(source(), newChild, inferenceId(), prompt, targetField);
+    public List<TaskType> supportedTaskTypes() {
+        return SUPPORTED_TASK_TYPES;
     }
 
     @Override
-    public TaskType taskType() {
-        return TaskType.COMPLETION;
+    public Completion replaceChild(LogicalPlan newChild) {
+        return new Completion(source(), newChild, inferenceId(), taskType(), prompt, targetField);
     }
 
     @Override
@@ -122,7 +140,7 @@ public List generatedAttributes() {
     @Override
     public Completion withGeneratedNames(List<String> newNames) {
         checkNumberOfNewNames(newNames);
-        return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0)));
+        return new Completion(source(), child(), inferenceId(), taskType(), prompt, this.renameTargetField(newNames.get(0)));
     }
 
     private Attribute renameTargetField(String newName) {
@@ -133,6 +151,11 @@ private Attribute renameTargetField(String newName) {
         return targetField.withName(newName).withId(new NameId());
     }
 
+    @Override
+    public Completion withResolvedInference(ResolvedInference resolvedInference) {
+        return super.withResolvedInference(resolvedInference);
+    }
+
     @Override
     protected AttributeSet computeReferences() {
         return prompt.references();
@@ -152,7 +175,7 @@ public void postAnalysisVerification(Failures failures) {
 
     @Override
     protected NodeInfo<? extends LogicalPlan> info() {
-        return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);
+        return NodeInfo.create(this, Completion::new, child(), inferenceId(), taskType(), prompt, targetField);
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
index 633ed74d8addb..70769cd7e0140 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.inference.ResolvedInference;
 import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
 import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -32,10 +33,12 @@ public abstract class InferencePlan> ex
     public static final List<String> VALID_INFERENCE_OPTION_NAMES = List.of(INFERENCE_ID_OPTION_NAME);
 
     private final Expression inferenceId;
+    private final TaskType taskType;
 
-    protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) {
+    protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId, TaskType taskType) {
         super(source, child);
         this.inferenceId = inferenceId;
+        this.taskType = taskType;
     }
 
     @Override
@@ -60,18 +63,40 @@ public boolean equals(Object o) {
         if (o == null || getClass() != o.getClass()) return false;
         if (super.equals(o) == false) return false;
         InferencePlan<?> other = (InferencePlan<?>) o;
-        return Objects.equals(inferenceId(), other.inferenceId());
+        return Objects.equals(inferenceId(), other.inferenceId()) && taskType == other.taskType;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), inferenceId());
+        return Objects.hash(super.hashCode(), inferenceId(), taskType);
     }
 
-    public abstract TaskType taskType();
+    public TaskType taskType() {
+        return taskType;
+    }
 
     public abstract PlanType withInferenceId(Expression newInferenceId);
 
+    public abstract List<TaskType> supportedTaskTypes();
+
+    @SuppressWarnings("unchecked")
+    public PlanType withResolvedInference(ResolvedInference resolvedInference) {
+        if (supportedTaskTypes().stream().noneMatch(resolvedInference.taskType()::equals)) {
+            String error = "cannot use inference endpoint ["
+                + resolvedInference.inferenceId()
+                + "] with task type ["
+                + resolvedInference.taskType()
+                + "] within a "
+                + nodeName()
+                + " command. Only inference endpoints with the task type "
+                + supportedTaskTypes()
+                + " are supported.";
+            return withInferenceResolutionError(resolvedInference.inferenceId(), error);
+        }
+
+        return (PlanType) this;
+    }
+
     public PlanType withInferenceResolutionError(String inferenceId, String error) {
         return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
index 6f86138397fa6..25894c75f2289 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
@@ -59,7 +59,7 @@ public Rerank(
         List<Alias> rerankFields,
         Attribute scoreAttribute
     ) {
-        super(source, child, inferenceId);
+        super(source, child, inferenceId, TaskType.RERANK);
         this.queryText = queryText;
         this.rerankFields = rerankFields;
         this.scoreAttribute = scoreAttribute;
@@ -96,11 +96,6 @@ public Attribute scoreAttribute() {
         return scoreAttribute;
     }
 
-    @Override
-    public TaskType taskType() {
-        return TaskType.RERANK;
-    }
-
     @Override
     public Rerank withInferenceId(Expression newInferenceId) {
         if (inferenceId().equals(newInferenceId)) {
@@ -109,6 +104,11 @@ public Rerank withInferenceId(Expression newInferenceId) {
         return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
     }
 
+    @Override
+    public List<TaskType> supportedTaskTypes() {
+        return List.of(TaskType.RERANK);
+    }
+
     public Rerank withRerankFields(List<Alias> newRerankFields) {
         if (rerankFields.equals(newRerankFields)) {
             return this;
@@ -163,7 +163,7 @@ public static AttributeSet computeReferences(List fields) {
     }
 
     public boolean isValidRerankField(Alias rerankField) {
-        // Only supportinng the following datatypes for now: text, numeric and boolean
+        // Only supporting the following datatypes for now: text, numeric and boolean
         return DataType.isString(rerankField.dataType())
             || rerankField.dataType() == DataType.BOOLEAN
             || rerankField.dataType().isNumeric();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 417eb0f1a7834..cf89671a2543e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -4026,7 +4026,7 @@ public void testResolveCompletionInferenceIdInvalidTaskType() {
             "mapping-books.json",
             new QueryParams(),
             "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
-                + " Only inference endpoints with the task type [completion] are supported"
+                + " Only inference endpoints with the task type [completion, chat_completion] are supported"
         );
     }
 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index 1a181fe805e81..8a912f3d7c11c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -5488,6 +5488,7 @@ record PushdownShadowingGeneratingPlanTestCase(
                 EMPTY,
                 plan,
                 randomLiteral(TEXT),
+                randomFrom(Completion.SUPPORTED_TASK_TYPES),
                 new Concat(EMPTY, randomLiteral(TEXT), List.of(attr)),
                 new ReferenceAttribute(EMPTY, "y", KEYWORD)
             ),
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java
index 11c64a82e3f57..988718ff365c2 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java
@@ -303,6 +303,7 @@ public void testPushDownFilterPastCompletion() {
                 EMPTY,
                 new Filter(EMPTY, relation, new And(EMPTY, conditionA, conditionB)),
                 completion.inferenceId(),
+                completion.taskType(),
                 completion.prompt(),
                 completion.targetField()
             ),
@@ -350,6 +351,7 @@ private static Completion completion(LogicalPlan child) {
             EMPTY,
             child,
             randomLiteral(DataType.KEYWORD),
+            randomFrom(Completion.SUPPORTED_TASK_TYPES),
             randomLiteral(randomBoolean() ? DataType.TEXT : DataType.KEYWORD),
             referenceAttribute(randomIdentifier(), DataType.KEYWORD)
         );
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java
index b1626e4b77ce8..973b770bedac3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java
@@ -77,7 +77,14 @@ public void checkOptimizedPlan(LogicalPlan basePlan, LogicalPlan optimizedPlan)
         ),
         new PushDownLimitTestCase<>(
             Completion.class,
-            (plan, attr) -> new Completion(EMPTY, plan, randomLiteral(KEYWORD), randomLiteral(KEYWORD), attr),
+            (plan, attr) -> new Completion(
+                EMPTY,
+                plan,
+                randomLiteral(KEYWORD),
+                randomFrom(Completion.SUPPORTED_TASK_TYPES),
+                randomLiteral(KEYWORD),
+                attr
+            ),
             (basePlan, optimizedPlan) -> {
                 assertEquals(basePlan.source(), optimizedPlan.source());
                 assertEquals(basePlan.inferenceId(), optimizedPlan.inferenceId());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java
index e9810454224aa..732335f958fc4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java
@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.esql.plan.logical.inference;
 
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
@@ -21,7 +22,14 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati
 
     @Override
     protected Completion createTestInstance() {
-        return new Completion(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
+        return new Completion(
+            randomSource(),
+            randomChild(0),
+            randomInferenceId(),
+            randomTaskTypeOrNull(),
+            randomPrompt(),
+            randomAttribute()
+        );
     }
 
     @Override
@@ -30,14 +38,16 @@ protected Completion mutateInstance(Completion instance) throws IOException {
         Expression inferenceId = instance.inferenceId();
         Expression prompt = instance.prompt();
         Attribute targetField = instance.targetField();
+        TaskType taskType = instance.taskType();
 
-        switch (between(0, 3)) {
+        switch (between(0, 4)) {
             case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
             case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
-            case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
-            case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
+            case 2 -> taskType = randomValueOtherThan(taskType, this::randomTaskTypeOrNull);
+            case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
+            case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
         }
-        return new Completion(instance.source(), child, inferenceId, prompt, targetField);
+        return new Completion(instance.source(), child, inferenceId, taskType, prompt, targetField);
     }
 
     private Literal randomInferenceId() {
@@ -51,4 +61,12 @@ private Expression randomPrompt() {
     private Attribute randomAttribute() {
         return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
     }
+
+    private TaskType randomTaskType() {
+        return randomFrom(Completion.SUPPORTED_TASK_TYPES);
+    }
+
+    private TaskType randomTaskTypeOrNull() {
+        return randomBoolean() ? randomTaskType() : null;
+    }
 }
From 7d986e6f7429acafef7add958b2dc0d4efabb42d Mon Sep 17 00:00:00 2001
From: afoucret 
Date: Fri, 12 Sep 2025 09:47:15 +0200
Subject: [PATCH 2/4] Add support for chat_completion to the Completion
 physical plan.
---
 .../plan/logical/inference/Completion.java    |  1 -
 .../physical/inference/CompletionExec.java    | 23 +++++++++++++++----
 .../physical/inference/InferenceExec.java     | 13 ++++++++---
 .../plan/physical/inference/RerankExec.java   |  3 ++-
 .../esql/planner/mapper/MapperUtils.java      |  9 +++++++-
 .../CompletionExecSerializationTests.java     | 21 +++++++++++++----
 6 files changed, 55 insertions(+), 15 deletions(-)
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
index cad7241c584b8..563bd105fc7cb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
@@ -86,7 +86,6 @@ public void writeTo(StreamOutput out) throws IOException {
         if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) {
             out.writeOptional((output, taskType) -> output.writeString(taskType.toString()), taskType());
         }
-
         out.writeNamedWriteable(prompt);
         out.writeNamedWriteable(targetField);
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java
index 80887ad08fe69..fc8d09b83566e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java
@@ -7,9 +7,11 @@
 
 package org.elasticsearch.xpack.esql.plan.physical.inference;
 
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -37,8 +39,15 @@ public class CompletionExec extends InferenceExec {
     private final Attribute targetField;
     private List<Attribute> lazyOutput;
 
-    public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
-        super(source, child, inferenceId);
+    public CompletionExec(
+        Source source,
+        PhysicalPlan child,
+        Expression inferenceId,
+        TaskType taskType,
+        Expression prompt,
+        Attribute targetField
+    ) {
+        super(source, child, inferenceId, taskType);
         this.prompt = prompt;
         this.targetField = targetField;
     }
@@ -48,6 +57,9 @@ public CompletionExec(StreamInput in) throws IOException {
             Source.readFrom((PlanStreamInput) in),
             in.readNamedWriteable(PhysicalPlan.class),
             in.readNamedWriteable(Expression.class),
+            in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)
+                ? in.readOptional(input -> TaskType.fromString(input.readString()))
+                : TaskType.COMPLETION,
             in.readNamedWriteable(Expression.class),
             in.readNamedWriteable(Attribute.class)
         );
@@ -61,6 +73,9 @@ public String getWriteableName() {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHAT_COMPLETION_SUPPORT)) { // task type is nullable
+            out.writeOptional((output, type) -> output.writeString(type.toString()), taskType());
+        }
         out.writeNamedWriteable(prompt);
         out.writeNamedWriteable(targetField);
     }
@@ -75,12 +90,12 @@ public Attribute targetField() {
 
     @Override
     protected NodeInfo<? extends PhysicalPlan> info() {
-        return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField);
+        return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), taskType(), prompt, targetField);
     }
 
     @Override
     public UnaryExec replaceChild(PhysicalPlan newChild) {
-        return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField);
+        return new CompletionExec(source(), newChild, inferenceId(), taskType(), prompt, targetField);
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
index d60a5ecccc384..de8c673728056 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.esql.plan.physical.inference;
 
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
@@ -18,16 +19,22 @@
 
 public abstract class InferenceExec extends UnaryExec {
     private final Expression inferenceId;
+    private final TaskType taskType;
 
-    protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
+    protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId, TaskType taskType) {
         super(source, child);
         this.inferenceId = inferenceId;
+        this.taskType = taskType;
     }
 
     public Expression inferenceId() {
         return inferenceId;
     }
 
+    public TaskType taskType() {
+        return taskType;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         source().writeTo(out);
@@ -41,11 +48,11 @@ public boolean equals(Object o) {
         if (o == null || getClass() != o.getClass()) return false;
         if (super.equals(o) == false) return false;
         InferenceExec that = (InferenceExec) o;
-        return inferenceId.equals(that.inferenceId);
+        return inferenceId.equals(that.inferenceId) && taskType == that.taskType;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), inferenceId());
+        return Objects.hash(super.hashCode(), inferenceId(), taskType);
     }
 }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
index ad852d0ac20db..28cb18b17ecdc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
@@ -10,6 +10,7 @@
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.esql.core.expression.Alias;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
@@ -48,7 +49,7 @@ public RerankExec(
         List<Alias> rerankFields,
         Attribute scoreAttribute
     ) {
-        super(source, child, inferenceId);
+        super(source, child, inferenceId, TaskType.RERANK);
         this.queryText = queryText;
         this.rerankFields = rerankFields;
         this.scoreAttribute = scoreAttribute;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
index aabb18326fe11..05cee4b6a4adf 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
@@ -103,7 +103,14 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
         }
 
         if (p instanceof Completion completion) {
-            return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField());
+            return new CompletionExec(
+                completion.source(),
+                child,
+                completion.inferenceId(),
+                completion.taskType(),
+                completion.prompt(),
+                completion.targetField()
+            );
         }
 
         if (p instanceof Enrich enrich) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java
index 9fd41a2432462..486ae323c982c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java
@@ -7,11 +7,13 @@
 
 package org.elasticsearch.xpack.esql.plan.physical.inference;
 
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
 import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 
@@ -20,7 +22,14 @@
 public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests {
     @Override
     protected CompletionExec createTestInstance() {
-        return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
+        return new CompletionExec(
+            randomSource(),
+            randomChild(0),
+            randomInferenceId(),
+            randomFrom(Completion.SUPPORTED_TASK_TYPES),
+            randomPrompt(),
+            randomAttribute()
+        );
     }
 
     @Override
@@ -29,14 +38,16 @@ protected CompletionExec mutateInstance(CompletionExec instance) throws IOExcept
         Expression inferenceId = instance.inferenceId();
         Expression prompt = instance.prompt();
         Attribute targetField = instance.targetField();
+        TaskType taskType = instance.taskType();
 
-        switch (between(0, 3)) {
+        switch (between(0, 4)) {
             case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
             case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
-            case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
-            case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
+            case 2 -> taskType = randomValueOtherThan(taskType, () -> randomFrom(Completion.SUPPORTED_TASK_TYPES));
+            case 3 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
+            case 4 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
         }
-        return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField);
+        return new CompletionExec(instance.source(), child, inferenceId, taskType, prompt, targetField);
     }
 
     private Literal randomInferenceId() {
From bc6eaaaa607b1cf3a3cd3a0caab315ed6107ba9f Mon Sep 17 00:00:00 2001
From: afoucret 
Date: Fri, 12 Sep 2025 19:10:22 +0200
Subject: [PATCH 3/4] Implements streaming support in the completion operator.
---
 .../bulk/BulkInferenceRequestItem.java        | 113 +++++++++++++
 .../bulk/BulkInferenceRequestIterator.java    |   6 +-
 .../inference/bulk/BulkInferenceRunner.java   |  87 ++++++----
 .../bulk/BulkInferenceStreamingHandler.java   | 156 ++++++++++++++++++
 .../completion/ChatCompletionOperator.java    |  94 +++++++++++
 ...ChatCompletionOperatorRequestIterator.java |  88 ++++++++++
 .../completion/CompletionOperator.java        |   2 +-
 .../CompletionOperatorRequestIterator.java    |  56 +------
 .../inference/completion/PromptReader.java    |  56 +++++++
 .../rerank/RerankOperatorRequestIterator.java |   7 +-
 .../plan/logical/inference/Completion.java    |  15 +-
 .../physical/inference/InferenceExec.java     |   4 +-
 .../esql/planner/LocalExecutionPlanner.java   |   8 +-
 .../bulk/BulkInferenceRunnerTests.java        |   5 +-
 ...ompletionOperatorRequestIteratorTests.java |   2 +-
 .../RerankOperatorRequestIteratorTests.java   |   2 +-
 16 files changed, 599 insertions(+), 102 deletions(-)
 create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
 create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java
 create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java
 create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java
 create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
new file mode 100644
index 0000000000000..d2a9c51182968
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.bulk;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
+
+import java.util.Objects;
+
+sealed public interface BulkInferenceRequestItem permits
+    BulkInferenceRequestItem.AbstractBulkInferenceRequestItem {
+
+    TaskType taskType();
+
+    T inferenceRequest();
+
+    BulkInferenceRequestItem withSeqNo(long seqNo);
+
+    Long seqNo();
+
+    static InferenceRequestItem from(InferenceAction.Request request) {
+        return new InferenceRequestItem(request);
+    }
+
+    static ChatCompletionRequestItem from(UnifiedCompletionAction.Request request) {
+        return new ChatCompletionRequestItem(request);
+    }
+
+    abstract sealed class AbstractBulkInferenceRequestItem implements BulkInferenceRequestItem
+        permits InferenceRequestItem, ChatCompletionRequestItem {
+        private final T request;
+        private final Long seqNo;
+
+        protected AbstractBulkInferenceRequestItem(T request) {
+            this(request, null);
+        }
+
+        protected AbstractBulkInferenceRequestItem(T request, Long seqNo) {
+            this.request = request;
+            this.seqNo = seqNo;
+        }
+
+        @Override
+        public T inferenceRequest() {
+            return request;
+        }
+
+        @Override
+        public Long seqNo() {
+            return seqNo;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (o == null || getClass() != o.getClass()) return false;
+            AbstractBulkInferenceRequestItem> that = (AbstractBulkInferenceRequestItem>) o;
+            return Objects.equals(request, that.request) && Objects.equals(seqNo, that.seqNo);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(request, seqNo);
+        }
+
+        @Override
+        public TaskType taskType() {
+            return request.getTaskType();
+        }
+    }
+
+    final class InferenceRequestItem extends AbstractBulkInferenceRequestItem {
+        private InferenceRequestItem(InferenceAction.Request request) {
+            this(request, null);
+        }
+
+        private InferenceRequestItem(InferenceAction.Request request, Long seqNo) {
+            super(request, seqNo);
+        }
+
+        @Override
+        public InferenceRequestItem withSeqNo(long seqNo) {
+            return new InferenceRequestItem(inferenceRequest(), seqNo);
+        }
+    }
+
+    final class ChatCompletionRequestItem extends AbstractBulkInferenceRequestItem {
+
+        private ChatCompletionRequestItem(UnifiedCompletionAction.Request request) {
+            this(request, null);
+        }
+
+        private ChatCompletionRequestItem(UnifiedCompletionAction.Request request, Long seqNo) {
+            super(request, seqNo);
+        }
+
+        @Override
+        public TaskType taskType() {
+            return TaskType.CHAT_COMPLETION;
+        }
+
+        @Override
+        public ChatCompletionRequestItem withSeqNo(long seqNo) {
+            return new ChatCompletionRequestItem(inferenceRequest(), seqNo);
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java
index 7327b182d0b6c..03739af92f7ff 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java
@@ -8,17 +8,15 @@
 package org.elasticsearch.xpack.esql.inference.bulk;
 
 import org.elasticsearch.core.Releasable;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 
 import java.util.Iterator;
 
-public interface BulkInferenceRequestIterator extends Iterator, Releasable {
+public interface BulkInferenceRequestIterator extends Iterator>, Releasable {
 
     /**
      * Returns an estimate of the number of requests that will be produced.
      *
-     * This is typically used to pre-allocate buffers or output to th appropriate size.
+     * This is typically used to pre-allocate buffers or output to the appropriate size.
      */
     int estimatedSize();
-
 }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
index 203a3031bcad4..4d35411e3d3bb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
@@ -10,8 +10,10 @@
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -175,12 +177,12 @@ private class BulkInferenceRequest {
          * to the request iterator.
          * 
          *
-         * @return A BulkRequestItem if a request and permit are available, null otherwise
+         * @return A BulkInferenceRequestItem if a request and permit are available, null otherwise
          */
-        private BulkRequestItem pollPendingRequest() {
+        private BulkInferenceRequestItem> pollPendingRequest() {
             synchronized (requests) {
                 if (requests.hasNext()) {
-                    return new BulkRequestItem(executionState.generateSeqNo(), requests.next());
+                    return requests.next().withSeqNo(executionState.generateSeqNo());
                 }
             }
 
@@ -226,7 +228,7 @@ private void executePendingRequests(int recursionDepth) {
                         }
                         return;
                     } else {
-                        BulkRequestItem bulkRequestItem = pollPendingRequest();
+                        BulkInferenceRequestItem> bulkRequestItem = pollPendingRequest();
 
                         if (bulkRequestItem == null) {
                             // No more requests available
@@ -234,14 +236,14 @@ private void executePendingRequests(int recursionDepth) {
                             permits.release();
 
                             // Check if another bulk request is pending for execution.
-                            BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
+                            BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
 
-                            while (nexBulkRequest == this) {
-                                nexBulkRequest = pendingBulkRequests.poll();
+                            while (nextBulkRequest == this) {
+                                nextBulkRequest = pendingBulkRequests.poll();
                             }
 
-                            if (nexBulkRequest != null) {
-                                executor.execute(nexBulkRequest::executePendingRequests);
+                            if (nextBulkRequest != null) {
+                                executor.execute(nextBulkRequest::executePendingRequests);
                             }
 
                             return;
@@ -275,9 +277,9 @@ private void executePendingRequests(int recursionDepth) {
                                         // Response has already been sent
                                         // No need to continue processing this bulk.
                                         // Check if another bulk request is pending for execution.
-                                        BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
-                                        if (nexBulkRequest != null) {
-                                            executor.execute(nexBulkRequest::executePendingRequests);
+                                        BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
+                                        if (nextBulkRequest != null) {
+                                            executor.execute(nextBulkRequest::executePendingRequests);
                                         }
                                         return;
                                     }
@@ -298,19 +300,26 @@ private void executePendingRequests(int recursionDepth) {
                         );
 
                         // Handle null requests (edge case in some iterators)
-                        if (bulkRequestItem.request() == null) {
+                        if (bulkRequestItem.inferenceRequest() == null) {
                             inferenceResponseListener.onResponse(null);
                             return;
                         }
 
                         // Execute the inference request with proper origin context
-                        executeAsyncWithOrigin(
-                            client,
-                            INFERENCE_ORIGIN,
-                            InferenceAction.INSTANCE,
-                            bulkRequestItem.request(),
-                            inferenceResponseListener
-                        );
+                        if (bulkRequestItem.taskType() == TaskType.CHAT_COMPLETION) {
+                            handleStreamingRequest(
+                                (UnifiedCompletionAction.Request) bulkRequestItem.inferenceRequest(),
+                                inferenceResponseListener
+                            );
+                        } else {
+                            executeAsyncWithOrigin(
+                                client,
+                                INFERENCE_ORIGIN,
+                                InferenceAction.INSTANCE,
+                                bulkRequestItem.inferenceRequest(),
+                                inferenceResponseListener
+                            );
+                        }
                     }
                 }
             } catch (Exception e) {
@@ -318,6 +327,30 @@ private void executePendingRequests(int recursionDepth) {
             }
         }
 
+        /**
+         * Handles streaming inference requests for chat completion tasks.
+         * <p>
+         * This method executes UnifiedCompletionAction requests and sets up proper streaming
+         * response handling through the BulkInferenceStreamingHandler. The streaming handler
+         * manages the asynchronous stream processing and ensures responses are properly
+         * delivered to the completion listener.
+         * </p>
+         *
+         * @param request  The UnifiedCompletionAction request to execute
+         * @param listener The listener to receive the final aggregated response
+         */
+        private void handleStreamingRequest(UnifiedCompletionAction.Request request, ActionListener listener) {
+            executeAsyncWithOrigin(
+                client,
+                INFERENCE_ORIGIN,
+                UnifiedCompletionAction.INSTANCE,
+                request,
+                listener.delegateFailureAndWrap((l, inferenceResponse) -> {
+                    inferenceResponse.publisher().subscribe(new BulkInferenceStreamingHandler(l));
+                })
+            );
+        }
+
         /**
          * Processes and delivers buffered responses in order, ensuring proper sequencing.
          * 
@@ -360,20 +393,6 @@ private void onBulkCompletion() {
         }
     }
 
-    /**
-     * Encapsulates an inference request with its associated sequence number.
-     * 
-     * The sequence number is used for ordering responses and tracking completion
-     * in the bulk execution state.
-     * 
-     *
-     * @param seqNo   Unique sequence number for this request in the bulk operation
-     * @param request The actual inference request to execute
-     */
-    private record BulkRequestItem(long seqNo, InferenceAction.Request request) {
-
-    }
-
     public static Factory factory(Client client) {
         return inferenceRunnerConfig -> new BulkInferenceRunner(client, inferenceRunnerConfig.maxOutstandingBulkRequests());
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java
new file mode 100644
index 0000000000000..ab693a0787a52
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceStreamingHandler.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.bulk;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+
+import java.util.List;
+import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Handles streaming inference responses for chat completion requests in bulk inference operations.
+ * 
+ * This class implements the Reactive Streams {@link Flow.Subscriber} interface to process
+ * streaming inference results from chat completion services. It accumulates content from
+ * streaming chunks and delivers the final aggregated response to the completion listener.
+ * 
+ * Ultimately, it constructs a {@link ChatCompletionResults} object containing the
+ * complete content from all streaming chunks received during the inference operation.
+ * By doing so, the result of chat_completion requests is made available in the same form as the legacy completion
+ * and can be consumed in the same way.
+ * 
+ */
+class BulkInferenceStreamingHandler implements Flow.Subscriber {
+
+    /**
+     * Flag to track whether this streaming session has completed to prevent duplicate processing.
+     */
+    private final AtomicBoolean isLastPart = new AtomicBoolean(false);
+
+    /**
+     * The subscription handle for controlling the flow of streaming data.
+     */
+    private Flow.Subscription subscription;
+
+    /**
+     * Buffer for accumulating content from streaming chunks into the final response.
+     */
+    private final StringBuilder resultBuilder = new StringBuilder();
+
+    /**
+     * Listener to receive the final aggregated inference response.
+     */
+    private final ActionListener inferenceResponseListener;
+
+    /**
+     * Creates a new streaming handler for processing inference chat_completion responses.
+     *
+     * @param inferenceResponseListener The listener that will receive the final aggregated response
+     *                                  once all streaming chunks have been processed
+     */
+    BulkInferenceStreamingHandler(ActionListener inferenceResponseListener) {
+        this.inferenceResponseListener = inferenceResponseListener;
+    }
+
+    /**
+     * Called when the streaming publisher is ready to start sending data.
+     * <p>
+     * This method establishes the subscription and requests the first chunk of data.
+     * If the streaming session has already completed, it cancels the subscription
+     * to prevent resource leaks.
+     * </p>
+     *
+     * @param subscription The subscription handle for controlling data flow
+     */
+    @Override
+    public void onSubscribe(Flow.Subscription subscription) {
+        if (isLastPart.get() == false) {
+            this.subscription = subscription;
+            subscription.request(1);
+        } else {
+            subscription.cancel();
+        }
+    }
+
+    /**
+     * Processes each streaming chunk as it arrives from the inference service.
+     * <p>
+     * This method extracts content from streaming chat completion chunks and accumulates
+     * it in the result builder. It handles the specific structure of streaming unified
+     * chat completion results, extracting text content from delta objects within choices.
+     * </p>
+     * <p>
+     * After processing each chunk, it requests the next chunk from the subscription
+     * to continue the streaming process.
+     * </p>
+     *
+     * @param item The streaming result item containing chunk data from the inference service
+     */
+    @Override
+    public void onNext(InferenceServiceResults.Result item) {
+        if (isLastPart.get() == false) {
+            if (item instanceof StreamingUnifiedChatCompletionResults.Results streamingChunkResults) {
+                for (var chunk : streamingChunkResults.chunks()) {
+                    for (var choice : chunk.choices()) {
+                        if (choice.delta() != null && choice.delta().content() != null) {
+                            resultBuilder.append(choice.delta().content());
+                        }
+                    }
+                }
+                subscription.request(1);
+            } else {
+                // Handle unexpected result types by requesting the next item
+                subscription.request(1);
+            }
+        }
+    }
+
+    /**
+     * Called when an error occurs during streaming processing.
+     * <p>
+     * This method ensures that errors are properly propagated to the inference listener
+     * and that the streaming session is marked as completed to prevent further processing.
+     * </p>
+     *
+     * @param throwable The error that occurred during streaming
+     */
+    @Override
+    public void onError(Throwable throwable) {
+        if (isLastPart.compareAndSet(false, true)) {
+            inferenceResponseListener.onFailure(new RuntimeException("Streaming inference failed", throwable));
+        }
+    }
+
+    /**
+     * Called when the streaming process completes successfully.
+     * <p>
+     * This method finalizes the streaming process by creating a complete inference response
+     * from the accumulated content and delivering it to the listener. It constructs a
+     * {@link ChatCompletionResults} object containing the aggregated content from all
+     * streaming chunks.
+     * </p>
+     */
+    @Override
+    public void onComplete() {
+        if (isLastPart.compareAndSet(false, true)) {
+            // Create the final aggregated response from accumulated content
+            String finalContent = resultBuilder.toString();
+            ChatCompletionResults.Result completionResult = new ChatCompletionResults.Result(finalContent);
+            ChatCompletionResults chatResults = new ChatCompletionResults(List.of(completionResult));
+            InferenceAction.Response response = new InferenceAction.Response(chatResults);
+
+            // Deliver the final response to the listener
+            inferenceResponseListener.onResponse(response);
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java
new file mode 100644
index 0000000000000..89dc03f86b953
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperator.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.completion;
+
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.xpack.esql.inference.InferenceOperator;
+import org.elasticsearch.xpack.esql.inference.InferenceService;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
+
+/**
+ * {@link ChatCompletionOperator} is an {@link InferenceOperator} that performs inference using chat_completion inference endpoints.
+ * It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output.
+ */
+public class ChatCompletionOperator extends InferenceOperator {
+
+    private final ExpressionEvaluator promptEvaluator;
+
+    public ChatCompletionOperator(
+        DriverContext driverContext,
+        BulkInferenceRunner bulkInferenceRunner,
+        String inferenceId,
+        ExpressionEvaluator promptEvaluator,
+        int maxOutstandingPages
+    ) {
+        super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
+        this.promptEvaluator = promptEvaluator;
+    }
+
+    @Override
+    protected void doClose() {
+        Releasables.close(promptEvaluator);
+    }
+
+    @Override
+    public String toString() {
+        return "ChatCompletionOperator[inference_id=[" + inferenceId() + "]]";
+    }
+
+    /**
+     * Constructs the chat completion inference requests iterator for the given input page by evaluating the prompt expression.
+     *
+     * @param inputPage The input data page.
+     */
+    @Override
+    protected BulkInferenceRequestIterator requests(Page inputPage) {
+        return new ChatCompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId());
+    }
+
+    /**
+     * Creates a new {@link CompletionOperatorOutputBuilder} to collect and emit the chat completion results.
+     *
+     * @param input The input page for which results will be constructed.
+     */
+    @Override
+    protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
+        BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount());
+        return new CompletionOperatorOutputBuilder(outputBlockBuilder, input);
+    }
+
+    /**
+     * Factory for creating {@link ChatCompletionOperator} instances.
+     */
+    public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
+        implements
+            OperatorFactory {
+        @Override
+        public String describe() {
+            return "ChatCompletionOperator[inference_id=[" + inferenceId + "]]";
+        }
+
+        @Override
+        public Operator get(DriverContext driverContext) {
+            return new ChatCompletionOperator(
+                driverContext,
+                inferenceService.bulkInferenceRunner(),
+                inferenceId,
+                promptEvaluatorFactory.get(driverContext),
+                BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests()
+            );
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java
new file mode 100644
index 0000000000000..7d0e0e3c6edd4
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/ChatCompletionOperatorRequestIterator.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.completion;
+
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link UnifiedCompletionAction.Request}
+ * instances of type {@link TaskType#CHAT_COMPLETION} (rather than the {@link InferenceAction.Request} used by other task types).
+ */
+public class ChatCompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
+
+    private final PromptReader promptReader;
+    private final String inferenceId;
+    private final int size;
+    private int currentPos = 0;
+
+    /**
+     * Constructs a new iterator from the given block of prompts.
+     *
+     * @param promptBlock The input block containing prompts.
+     * @param inferenceId The ID of the inference model to invoke.
+     */
+    public ChatCompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
+        this.promptReader = new PromptReader(promptBlock);
+        this.size = promptBlock.getPositionCount();
+        this.inferenceId = inferenceId;
+    }
+
+    @Override
+    public boolean hasNext() {
+        return currentPos < size;
+    }
+
+    @Override
+    public BulkInferenceRequestItem.ChatCompletionRequestItem next() {
+        if (hasNext() == false) {
+            throw new NoSuchElementException();
+        }
+
+        UnifiedCompletionAction.Request inferenceRequest = inferenceRequest(promptReader.readPrompt(currentPos++));
+        return BulkInferenceRequestItem.from(inferenceRequest);
+    }
+
+    /**
+     * Wraps a single prompt into a {@link UnifiedCompletionAction.Request} ({@code null} for a null prompt).
+     */
+    private UnifiedCompletionAction.Request inferenceRequest(String prompt) {
+        if (prompt == null) {
+            return null;
+        }
+
+        return new UnifiedCompletionAction.Request(
+            inferenceId,
+            TaskType.CHAT_COMPLETION,
+            UnifiedCompletionRequest.of(
+                List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(prompt), "user", null, null))
+            ),
+            TimeValue.THIRTY_SECONDS
+        );
+    }
+
+    @Override
+    public int estimatedSize() {
+        return promptReader.estimatedSize();
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(promptReader);
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
index 65b560f3cf9ce..023c4670b6bc0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
@@ -20,7 +20,7 @@
 import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
 
 /**
- * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using prompt-based model (e.g., text completion).
+ * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using completion inference endpoints.
  * It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output.
  */
 public class CompletionOperator extends InferenceOperator {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
index f526cd9edb077..dc9922b3ecd55 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
@@ -7,12 +7,11 @@
 
 package org.elasticsearch.xpack.esql.inference.completion;
 
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem;
 import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
 
 import java.util.List;
@@ -47,12 +46,12 @@ public boolean hasNext() {
     }
 
     @Override
-    public InferenceAction.Request next() {
+    public BulkInferenceRequestItem next() {
         if (hasNext() == false) {
             throw new NoSuchElementException();
         }
 
-        return inferenceRequest(promptReader.readPrompt(currentPos++));
+        return BulkInferenceRequestItem.from(inferenceRequest(promptReader.readPrompt(currentPos++)));
     }
 
     /**
@@ -75,53 +74,4 @@ public int estimatedSize() {
     public void close() {
         Releasables.close(promptReader);
     }
-
-    /**
-     * Helper class that reads prompts from a {@link BytesRefBlock}.
-     */
-    private static class PromptReader implements Releasable {
-        private final BytesRefBlock promptBlock;
-        private final StringBuilder strBuilder = new StringBuilder();
-        private BytesRef readBuffer = new BytesRef();
-
-        private PromptReader(BytesRefBlock promptBlock) {
-            this.promptBlock = promptBlock;
-        }
-
-        /**
-         * Reads the prompt string at the given position..
-         *
-         * @param pos the position index in the block
-         */
-        public String readPrompt(int pos) {
-            if (promptBlock.isNull(pos)) {
-                return null;
-            }
-
-            strBuilder.setLength(0);
-
-            for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
-                readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
-                strBuilder.append(readBuffer.utf8ToString());
-                if (valueIndex != promptBlock.getValueCount(pos) - 1) {
-                    strBuilder.append("\n");
-                }
-            }
-
-            return strBuilder.toString();
-        }
-
-        /**
-         * Returns the total number of positions (prompts) in the block.
-         */
-        public int estimatedSize() {
-            return promptBlock.getPositionCount();
-        }
-
-        @Override
-        public void close() {
-            promptBlock.allowPassingToDifferentDriver();
-            Releasables.close(promptBlock);
-        }
-    }
 }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java
new file mode 100644
index 0000000000000..58c509fd6f021
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/PromptReader.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.completion;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+
+/**
+ * Helper class that reads prompts from a {@link BytesRefBlock}.
+ * <p>
+ * A position holding several values is read as a single prompt in which the values are joined with {@code "\n"}.
+ * The reader reuses one internal {@link StringBuilder} and one scratch {@link BytesRef} across calls.
+ */
+public class PromptReader implements Releasable {
+    private final BytesRefBlock promptBlock;
+    private final StringBuilder strBuilder = new StringBuilder();
+    private BytesRef readBuffer = new BytesRef();
+
+    /**
+     * Creates a reader over the given block; the block is released when this reader is closed.
+     *
+     * @param promptBlock the block to read prompts from
+     */
+    public PromptReader(BytesRefBlock promptBlock) {
+        this.promptBlock = promptBlock;
+    }
+
+    /**
+     * Reads the prompt string at the given position.
+     *
+     * @param pos the position index in the block
+     * @return the prompt, with multiple values joined by newlines, or {@code null} if the position is null
+     */
+    public String readPrompt(int pos) {
+        if (promptBlock.isNull(pos)) {
+            return null;
+        }
+        strBuilder.setLength(0);
+        // Both lookups depend only on pos, so resolve them once instead of on every iteration.
+        final int firstValueIndex = promptBlock.getFirstValueIndex(pos);
+        final int valueCount = promptBlock.getValueCount(pos);
+        for (int valueIndex = 0; valueIndex < valueCount; valueIndex++) {
+            if (valueIndex > 0) {
+                strBuilder.append("\n");
+            }
+            readBuffer = promptBlock.getBytesRef(firstValueIndex + valueIndex, readBuffer);
+            strBuilder.append(readBuffer.utf8ToString());
+        }
+        return strBuilder.toString();
+    }
+
+    /**
+     * Returns the total number of positions (prompts) in the block.
+     */
+    public int estimatedSize() {
+        return promptBlock.getPositionCount();
+    }
+
+    /**
+     * Marks the block as allowed to pass to a different driver, then releases it.
+     */
+    @Override
+    public void close() {
+        promptBlock.allowPassingToDifferentDriver();
+        Releasables.close(promptBlock);
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java
index 4b1cfe5870ad7..fe33c799aaec8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java
@@ -13,6 +13,7 @@
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestItem;
 import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
 
 import java.util.ArrayList;
@@ -46,7 +47,7 @@ public boolean hasNext() {
     }
 
     @Override
-    public InferenceAction.Request next() {
+    public BulkInferenceRequestItem next() {
         if (hasNext() == false) {
             throw new NoSuchElementException();
         }
@@ -59,7 +60,7 @@ public InferenceAction.Request next() {
 
         if (inputBlock.isNull(startIndex)) {
             remainingPositions -= 1;
-            return null;
+            return BulkInferenceRequestItem.from((InferenceAction.Request) null);
         }
 
         for (int i = 0; i < maxInputSize; i++) {
@@ -73,7 +74,7 @@ public InferenceAction.Request next() {
         }
 
         remainingPositions -= inputs.size();
-        return inferenceRequest(inputs);
+        return BulkInferenceRequestItem.from(inferenceRequest(inputs));
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
index 563bd105fc7cb..12950dcafbf6c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java
@@ -152,7 +152,20 @@ private Attribute renameTargetField(String newName) {
 
     @Override
     public Completion withResolvedInference(ResolvedInference resolvedInference) {
-        return super.withResolvedInference(resolvedInference);
+        Completion completion = super.withResolvedInference(resolvedInference);
+
+        if (completion.inferenceId().resolved()) {
+            return new Completion(
+                source(),
+                child(),
+                completion.inferenceId(),
+                resolvedInference.taskType(),
+                completion.prompt(),
+                completion.targetField()
+            );
+        }
+
+        return completion;
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
index de8c673728056..f973fcd58e346 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
@@ -23,8 +23,8 @@ public abstract class InferenceExec extends UnaryExec {
 
     protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId, TaskType taskType) {
         super(source, child);
-        this.inferenceId = inferenceId;
-        this.taskType = taskType;
+        this.inferenceId = Objects.requireNonNull(inferenceId);
+        this.taskType = Objects.requireNonNull(taskType);
     }
 
     public Expression inferenceId() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index 878d223535df5..29fee28801aa8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -89,6 +89,7 @@
 import org.elasticsearch.xpack.esql.expression.Order;
 import org.elasticsearch.xpack.esql.inference.InferenceService;
 import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
+import org.elasticsearch.xpack.esql.inference.completion.ChatCompletionOperator;
 import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator;
 import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator;
 import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -323,7 +324,12 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti
             source.layout
         );
 
-        return source.with(new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory), outputLayout);
+        OperatorFactory operatorFactory = switch (completion.taskType()) {
+            case CHAT_COMPLETION -> new ChatCompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory);
+            default -> new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory);
+        };
+
+        return source.with(operatorFactory, outputLayout);
     }
 
     private PhysicalOperation planFuseScoreEvalExec(FuseScoreEvalExec fuse, LocalExecutionPlannerContext context) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java
index dedbf895860b9..02eec173eb3cd 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java
@@ -193,7 +193,10 @@ private BulkInferenceRunnerConfig randomBulkExecutionConfig() {
     }
 
     private BulkInferenceRequestIterator requestIterator(List requests) {
-        final Iterator delegate = requests.iterator();
+        final Iterator> delegate = requests.stream()
+            .map(BulkInferenceRequestItem::from)
+            .toList()
+            .iterator();
         BulkInferenceRequestIterator iterator = mock(BulkInferenceRequestIterator.class);
         doAnswer(i -> delegate.hasNext()).when(iterator).hasNext();
         doAnswer(i -> delegate.next()).when(iterator).next();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java
index 86592256d26bc..63b9a7677911c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java
@@ -32,7 +32,7 @@ private void assertIterate(int size) throws Exception {
             BytesRef scratch = new BytesRef();
 
             for (int currentPos = 0; requestIterator.hasNext(); currentPos++) {
-                InferenceAction.Request request = requestIterator.next();
+                InferenceAction.Request request = requestIterator.next().inferenceRequest();
                 assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
                 scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch);
                 assertThat(request.getInput().getFirst(), equalTo(scratch.utf8ToString()));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java
index 72397efcf1be3..010a1ed35816a 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java
@@ -37,7 +37,7 @@ private void assertIterate(int size, int batchSize) throws Exception {
             BytesRef scratch = new BytesRef();
 
             for (int currentPos = 0; requestIterator.hasNext();) {
-                InferenceAction.Request request = requestIterator.next();
+                InferenceAction.Request request = requestIterator.next().inferenceRequest();
 
                 assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
                 assertThat(request.getQuery(), equalTo(queryText));
From c7a456abf0d81a630ae15f63ff16167d4d4592cb Mon Sep 17 00:00:00 2001
From: afoucret 
Date: Sat, 13 Sep 2025 06:32:42 +0200
Subject: [PATCH 4/4] Checkstyle fix
---
 .../xpack/esql/inference/bulk/BulkInferenceRequestItem.java     | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
index d2a9c51182968..20b1bc24143da 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestItem.java
@@ -14,7 +14,7 @@
 
 import java.util.Objects;
 
-sealed public interface BulkInferenceRequestItem permits
+public sealed interface BulkInferenceRequestItem permits
     BulkInferenceRequestItem.AbstractBulkInferenceRequestItem {
 
     TaskType taskType();