Skip to content

Commit d697c46

Browse files
committed
Draft CompletionOperator.
1 parent 5f5ffdc commit d697c46

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionOperator.java

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,22 @@
77

88
package org.elasticsearch.xpack.esql.inference;
99

10+
import org.apache.lucene.util.BytesRef;
11+
import org.apache.lucene.util.BytesRefBuilder;
1012
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.action.support.CountDownActionListener;
14+
import org.elasticsearch.compute.data.BlockFactory;
15+
import org.elasticsearch.compute.data.BytesRefBlock;
1116
import org.elasticsearch.compute.data.Page;
1217
import org.elasticsearch.compute.operator.DriverContext;
1318
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
1419
import org.elasticsearch.compute.operator.Operator;
20+
import org.elasticsearch.inference.TaskType;
21+
import org.elasticsearch.logging.LogManager;
22+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
23+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
24+
25+
import java.util.List;
1526

1627
public class CompletionOperator extends InferenceOperator<Page> {
1728

@@ -30,6 +41,7 @@ public Operator get(DriverContext driverContext) {
3041
}
3142

3243
private final ExpressionEvaluator promptEvaluator;
44+
private final BlockFactory blockFactory;
3345

3446
public CompletionOperator(
3547
DriverContext driverContext,
@@ -39,12 +51,66 @@ public CompletionOperator(
3951
) {
4052
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId);
4153
this.promptEvaluator = promptEvaluator;
54+
this.blockFactory = driverContext.blockFactory();
4255
}
4356

4457
@Override
4558
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
46-
Page outputPage = inputPage.appendBlock(promptEvaluator.eval(inputPage));
47-
listener.onResponse(outputPage);
59+
int pageSize = inputPage.getPositionCount();
60+
String[] responses = new String[pageSize];
61+
62+
CountDownActionListener countDownListener = new CountDownActionListener(
63+
inputPage.getPositionCount(),
64+
listener.delegateFailureIgnoreResponseAndWrap(l -> {
65+
try(BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(pageSize)) {
66+
BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
67+
for (int pos = 0; pos < pageSize; pos++) {
68+
if (responses[pos] == null) {
69+
outputBlockBuilder.appendNull();
70+
} else {
71+
bytesRefBuilder.copyChars(responses[pos]);
72+
outputBlockBuilder.appendBytesRef(bytesRefBuilder.get());
73+
}
74+
}
75+
76+
l.onResponse(inputPage.appendBlock(outputBlockBuilder.build()));
77+
}
78+
})
79+
);
80+
81+
try (BytesRefBlock promptBlock = (BytesRefBlock) promptEvaluator.eval(inputPage)) {
82+
BytesRef readBuffer = new BytesRef();
83+
for (int pos = 0; pos < pageSize; pos++) {
84+
final int currentPos = pos;
85+
if (promptBlock.isNull(pos)) {
86+
countDownListener.onResponse(null);
87+
} else {
88+
StringBuilder promptBuilder = new StringBuilder();
89+
for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
90+
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
91+
promptBuilder.append(readBuffer.utf8ToString()).append("\n");
92+
93+
94+
InferenceAction.Request request = InferenceAction.Request.builder(inferenceId(), TaskType.COMPLETION)
95+
.setInput(List.of(promptBuilder.toString())).build();
96+
97+
doInference(request, countDownListener.delegateFailureAndWrap((l, r) -> {
98+
if (r.getResults() instanceof ChatCompletionResults completionResults) {
99+
responses[currentPos] = completionResults.results().getFirst().content();
100+
l.onResponse(null);
101+
} else {
102+
l.onFailure(new IllegalStateException(
103+
"Inference result has wrong type. Got ["
104+
+ r.getResults().getClass()
105+
+ "] while expecting ["
106+
+ ChatCompletionResults.class
107+
+ "]"
108+
));
109+
}
110+
}));
111+
}
112+
}
113+
}
48114
}
49115

50116
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ protected AttributeSet computeReferences() {
136136

137137
@Override
138138
public boolean expressionsResolved() {
139-
return super.expressionsResolved() && prompt.resolved();
139+
return super.expressionsResolved() && prompt.resolved() && targetField.resolved();
140140
}
141141

142142
@Override

0 commit comments

Comments
 (0)