Skip to content

Commit 28cdd88

Browse files
committed
CompletionOperator skeleton.
1 parent db55fec commit 28cdd88

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.compute.data.Page;
12+
import org.elasticsearch.compute.operator.DriverContext;
13+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
14+
import org.elasticsearch.compute.operator.Operator;
15+
16+
public class CompletionOperator extends InferenceOperator<Page> {
17+
18+
public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
19+
implements
20+
OperatorFactory {
21+
@Override
22+
public String describe() {
23+
return "RerankOperator[inference_id=[" + inferenceId + "]]";
24+
}
25+
26+
@Override
27+
public Operator get(DriverContext driverContext) {
28+
return new CompletionOperator(driverContext, inferenceRunner, inferenceId, promptEvaluatorFactory.get(driverContext));
29+
}
30+
}
31+
32+
private final ExpressionEvaluator promptEvaluator;
33+
34+
public CompletionOperator(
35+
DriverContext driverContext,
36+
InferenceRunner inferenceRunner,
37+
String inferenceId,
38+
ExpressionEvaluator promptEvaluator
39+
) {
40+
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId);
41+
this.promptEvaluator = promptEvaluator;
42+
}
43+
44+
@Override
45+
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
46+
Page outputPage = inputPage.appendBlock(promptEvaluator.eval(inputPage));
47+
listener.onResponse(outputPage);
48+
}
49+
50+
@Override
51+
protected void doClose() {
52+
53+
}
54+
55+
@Override
56+
protected void releaseFetchedOnAnyThread(Page page) {
57+
releasePageOnAnyThread(page);
58+
}
59+
60+
@Override
61+
public Page getOutput() {
62+
return fetchFromBuffer();
63+
}
64+
65+
@Override
66+
public String toString() {
67+
return "CompletionOperator[inference_id=[" + inferenceId() + "]]";
68+
}
69+
70+
}

0 commit comments

Comments
 (0)