2525
2626import java .util .List ;
2727import java .util .NoSuchElementException ;
28+ import java .util .concurrent .atomic .AtomicBoolean ;
2829
2930public class CompletionOperator extends InferenceOperator <ChatCompletionResults > {
3031
@@ -33,7 +34,7 @@ public record Factory(InferenceRunner inferenceRunner, String inferenceId, Expre
3334 OperatorFactory {
3435 @ Override
3536 public String describe () {
36- return "Completion [inference_id=[" + inferenceId + "]]" ;
37+ return "CompletionOperator [inference_id=[" + inferenceId + "]]" ;
3738 }
3839
3940 @ Override
@@ -73,82 +74,102 @@ public String toString() {
7374
7475 @ Override
7576 protected BulkInferenceRequestIterator requests (Page inputPage ) {
76- return new BulkInferenceRequestIterator () {
77- private final BytesRefBlock promptBlock = (BytesRefBlock ) promptEvaluator .eval (inputPage );
78- private BytesRef readBuffer = new BytesRef ();
79- private int currentPos = 0 ;
80-
81- @ Override
82- public boolean hasNext () {
83- return currentPos < promptBlock .getPositionCount ();
84- }
85-
86- @ Override
87- public InferenceAction .Request next () {
88- if (hasNext () == false ) {
89- throw new NoSuchElementException ();
77+ final BytesRefBlock promptBlock = (BytesRefBlock ) promptEvaluator .eval (inputPage );
78+ try {
79+ return new BulkInferenceRequestIterator () {
80+ private int currentPos = 0 ;
81+ BytesRef readBuffer = new BytesRef ();
82+ @ Override
83+ public boolean hasNext () {
84+ return currentPos < promptBlock .getPositionCount ();
9085 }
91- int pos = currentPos ++;
9286
93- if (promptBlock .isNull (pos )) {
94- return null ;
95- }
87+ @ Override
88+ public InferenceAction .Request next () {
89+ if (hasNext () == false ) {
90+ throw new NoSuchElementException ();
91+ }
92+ int pos = currentPos ++;
9693
97- StringBuilder promptBuilder = new StringBuilder ();
98- for (int valueIndex = 0 ; valueIndex < promptBlock .getValueCount (pos ); valueIndex ++) {
99- readBuffer = promptBlock .getBytesRef (promptBlock .getFirstValueIndex (pos ) + valueIndex , readBuffer );
100- promptBuilder .append (readBuffer .utf8ToString ()).append ("\n " );
101- }
94+ if (promptBlock .isNull (pos )) {
95+ return null ;
96+ }
97+
98+ StringBuilder promptBuilder = new StringBuilder ();
99+ for (int valueIndex = 0 ; valueIndex < promptBlock .getValueCount (pos ); valueIndex ++) {
100+ readBuffer = promptBlock .getBytesRef (promptBlock .getFirstValueIndex (pos ) + valueIndex , readBuffer );
101+ promptBuilder .append (readBuffer .utf8ToString ()).append ("\n " );
102+ }
102103
103- return inferenceRequest (promptBuilder .toString ());
104- }
104+ return inferenceRequest (promptBuilder .toString ());
105+ }
105106
106- @ Override
107- public void close () {
108- promptBlock .allowPassingToDifferentDriver ();
109- Releasables .closeExpectNoException (promptBlock );
110- }
111- };
107+ @ Override
108+ public void close () {
109+ promptBlock .allowPassingToDifferentDriver ();
110+ Releasables .closeExpectNoException (promptBlock );
111+ }
112+ };
113+ } catch (Exception e ) {
114+ promptBlock .allowPassingToDifferentDriver ();
115+ Releasables .closeExpectNoException (promptBlock );
116+ throw (e );
117+ }
112118 }
113119
114120 @ Override
115121 protected BulkInferenceOutputBuilder <ChatCompletionResults , Page > outputBuilder (Page inputPage ) {
116- return new BulkInferenceOutputBuilder <>() {
117- private final BytesRefBlock .Builder outputBlockBuilder = blockFactory ().newBytesRefBlockBuilder (inputPage .getPositionCount ());
118- private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder ();
119-
120- @ Override
121- public void close () {
122- Releasables .closeExpectNoException (outputBlockBuilder );
123- }
124-
125- @ Override
126- public void onInferenceResults (ChatCompletionResults completionResults ) {
127- if (completionResults == null || completionResults .getResults ().isEmpty ()) {
128- outputBlockBuilder .appendNull ();
129- } else {
130- outputBlockBuilder .beginPositionEntry ();
131- for (ChatCompletionResults .Result rankedDocsResult : completionResults .getResults ()) {
132- bytesRefBuilder .copyChars (rankedDocsResult .content ());
133- outputBlockBuilder .appendBytesRef (bytesRefBuilder .get ());
134- bytesRefBuilder .clear ();
122+ final BytesRefBlock .Builder outputBlockBuilder = blockFactory ().newBytesRefBlockBuilder (inputPage .getPositionCount ());
123+ final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder ();
124+ final AtomicBoolean isOutputBuilt = new AtomicBoolean (false );
125+
126+ try {
127+ return new BulkInferenceOutputBuilder <>() {
128+ @ Override
129+ public void close () {
130+ if (isOutputBuilt .get () == false ) {
131+ releasePageOnAnyThread (inputPage );
132+ }
133+
134+ Releasables .closeExpectNoException (outputBlockBuilder );
135+ }
136+
137+ @ Override
138+ public void onInferenceResults (ChatCompletionResults completionResults ) {
139+ if (completionResults == null || completionResults .getResults ().isEmpty ()) {
140+ outputBlockBuilder .appendNull ();
141+ } else {
142+ outputBlockBuilder .beginPositionEntry ();
143+ for (ChatCompletionResults .Result rankedDocsResult : completionResults .getResults ()) {
144+ bytesRefBuilder .copyChars (rankedDocsResult .content ());
145+ outputBlockBuilder .appendBytesRef (bytesRefBuilder .get ());
146+ bytesRefBuilder .clear ();
147+ }
148+ outputBlockBuilder .endPositionEntry ();
149+ }
150+ }
151+
152+ @ Override
153+ protected Class <ChatCompletionResults > inferenceResultsClass () {
154+ return ChatCompletionResults .class ;
155+ }
156+
157+ @ Override
158+ public Page buildOutput () {
159+ if (isOutputBuilt .compareAndSet (false , true )) {
160+ Block outputBlock = outputBlockBuilder .build ();
161+ assert outputBlock .getPositionCount () == inputPage .getPositionCount ();
162+ return inputPage .appendBlock (outputBlock );
135163 }
136- outputBlockBuilder .endPositionEntry ();
164+
165+ throw new IllegalStateException ("buildOutput has already been called" );
137166 }
138- }
139-
140- @ Override
141- protected Class <ChatCompletionResults > inferenceResultsClass () {
142- return ChatCompletionResults .class ;
143- }
144-
145- @ Override
146- public Page buildOutput () {
147- Block outputBlock = outputBlockBuilder .build ();
148- assert outputBlock .getPositionCount () == inputPage .getPositionCount ();
149- return inputPage .appendBlock (outputBlock );
150- }
151- };
167+ };
168+ } catch (Exception e ) {
169+ releasePageOnAnyThread (inputPage );
170+ Releasables .closeExpectNoException (outputBlockBuilder );
171+ throw (e );
172+ }
152173 }
153174
154175 private InferenceAction .Request inferenceRequest (String prompt ) {
0 commit comments