77
88package org .elasticsearch .xpack .esql .inference ;
99
10+ import org .apache .lucene .util .BytesRef ;
11+ import org .apache .lucene .util .BytesRefBuilder ;
1012import org .elasticsearch .action .ActionListener ;
13+ import org .elasticsearch .action .support .CountDownActionListener ;
14+ import org .elasticsearch .compute .data .BlockFactory ;
15+ import org .elasticsearch .compute .data .BytesRefBlock ;
1116import org .elasticsearch .compute .data .Page ;
1217import org .elasticsearch .compute .operator .DriverContext ;
1318import org .elasticsearch .compute .operator .EvalOperator .ExpressionEvaluator ;
1419import org .elasticsearch .compute .operator .Operator ;
20+ import org .elasticsearch .inference .TaskType ;
21+ import org .elasticsearch .logging .LogManager ;
22+ import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
23+ import org .elasticsearch .xpack .core .inference .results .ChatCompletionResults ;
24+
25+ import java .util .List ;
1526
1627public class CompletionOperator extends InferenceOperator <Page > {
1728
@@ -30,6 +41,7 @@ public Operator get(DriverContext driverContext) {
3041 }
3142
3243 private final ExpressionEvaluator promptEvaluator ;
44+ private final BlockFactory blockFactory ;
3345
3446 public CompletionOperator (
3547 DriverContext driverContext ,
@@ -39,12 +51,66 @@ public CompletionOperator(
3951 ) {
4052 super (driverContext , inferenceRunner .getThreadContext (), inferenceRunner , inferenceId );
4153 this .promptEvaluator = promptEvaluator ;
54+ this .blockFactory = driverContext .blockFactory ();
4255 }
4356
4457 @ Override
4558 protected void performAsync (Page inputPage , ActionListener <Page > listener ) {
46- Page outputPage = inputPage .appendBlock (promptEvaluator .eval (inputPage ));
47- listener .onResponse (outputPage );
59+ int pageSize = inputPage .getPositionCount ();
60+ String [] responses = new String [pageSize ];
61+
62+ CountDownActionListener countDownListener = new CountDownActionListener (
63+ inputPage .getPositionCount (),
64+ listener .delegateFailureIgnoreResponseAndWrap (l -> {
65+ try (BytesRefBlock .Builder outputBlockBuilder = blockFactory .newBytesRefBlockBuilder (pageSize )) {
66+ BytesRefBuilder bytesRefBuilder = new BytesRefBuilder ();
67+ for (int pos = 0 ; pos < pageSize ; pos ++) {
68+ if (responses [pos ] == null ) {
69+ outputBlockBuilder .appendNull ();
70+ } else {
71+ bytesRefBuilder .copyChars (responses [pos ]);
72+ outputBlockBuilder .appendBytesRef (bytesRefBuilder .get ());
73+ }
74+ }
75+
76+ l .onResponse (inputPage .appendBlock (outputBlockBuilder .build ()));
77+ }
78+ })
79+ );
80+
81+ try (BytesRefBlock promptBlock = (BytesRefBlock ) promptEvaluator .eval (inputPage )) {
82+ BytesRef readBuffer = new BytesRef ();
83+ for (int pos = 0 ; pos < pageSize ; pos ++) {
84+ final int currentPos = pos ;
85+ if (promptBlock .isNull (pos )) {
86+ countDownListener .onResponse (null );
87+ } else {
88+ StringBuilder promptBuilder = new StringBuilder ();
89+ for (int valueIndex = 0 ; valueIndex < promptBlock .getValueCount (pos ); valueIndex ++) {
90+ readBuffer = promptBlock .getBytesRef (promptBlock .getFirstValueIndex (pos ) + valueIndex , readBuffer );
91+ promptBuilder .append (readBuffer .utf8ToString ()).append ("\n " );
92+
93+
94+ InferenceAction .Request request = InferenceAction .Request .builder (inferenceId (), TaskType .COMPLETION )
95+ .setInput (List .of (promptBuilder .toString ())).build ();
96+
97+ doInference (request , countDownListener .delegateFailureAndWrap ((l , r ) -> {
98+ if (r .getResults () instanceof ChatCompletionResults completionResults ) {
99+ responses [currentPos ] = completionResults .results ().getFirst ().content ();
100+ l .onResponse (null );
101+ } else {
102+ l .onFailure (new IllegalStateException (
103+ "Inference result has wrong type. Got ["
104+ + r .getResults ().getClass ()
105+ + "] while expecting ["
106+ + ChatCompletionResults .class
107+ + "]"
108+ ));
109+ }
110+ }));
111+ }
112+ }
113+ }
48114 }
49115
50116 @ Override
0 commit comments