1919import org .elasticsearch .compute .operator .DriverContext ;
2020import org .elasticsearch .compute .operator .EvalOperator .ExpressionEvaluator ;
2121import org .elasticsearch .compute .operator .Operator ;
22+ import org .elasticsearch .core .Releasable ;
2223import org .elasticsearch .core .Releasables ;
2324import org .elasticsearch .inference .TaskType ;
2425import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2526import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
2627
2728import java .util .List ;
2829
29- public class RerankOperator extends AsyncOperator <Page > {
30+ public class RerankOperator extends AsyncOperator <RerankOperator . OngoingRerank > {
3031
3132 // Move to a setting.
3233 private static final int MAX_INFERENCE_WORKER = 10 ;
@@ -85,20 +86,16 @@ public RerankOperator(
8586 }
8687
8788 @ Override
88- protected void performAsync (Page inputPage , ActionListener <Page > listener ) {
89+ protected void performAsync (Page inputPage , ActionListener <OngoingRerank > listener ) {
8990 // Ensure input page blocks are released when the listener is called.
90- final ActionListener <Page > outputListener = ActionListener .runAfter (listener , () -> { releasePageOnAnyThread (inputPage ); });
91-
91+ listener = listener .delegateResponse ((l , e ) -> {
92+ releasePageOnAnyThread (inputPage );
93+ l .onFailure (e );
94+ });
9295 try {
93- inferenceRunner .doInference (
94- buildInferenceRequest (inputPage ),
95- ActionListener .wrap (
96- inferenceResponse -> outputListener .onResponse (buildOutput (inputPage , inferenceResponse )),
97- outputListener ::onFailure
98- )
99- );
96+ inferenceRunner .doInference (buildInferenceRequest (inputPage ), listener .map (resp -> new OngoingRerank (inputPage , resp )));
10097 } catch (Exception e ) {
101- outputListener .onFailure (e );
98+ listener .onFailure (e );
10299 }
103100 }
104101
@@ -108,91 +105,106 @@ protected void doClose() {
108105 }
109106
110107 @ Override
111- protected void releaseFetchedOnAnyThread (Page page ) {
112- releasePageOnAnyThread (page );
108+ protected void releaseFetchedOnAnyThread (OngoingRerank result ) {
109+ releasePageOnAnyThread (result . inputPage );
113110 }
114111
115112 @ Override
116113 public Page getOutput () {
117- return fetchFromBuffer ();
114+ var fetched = fetchFromBuffer ();
115+ if (fetched == null ) {
116+ return null ;
117+ } else {
118+ return fetched .buildOutput (blockFactory , scoreChannel );
119+ }
118120 }
119121
120122 @ Override
121123 public String toString () {
122124 return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]" ;
123125 }
124126
125- private Page buildOutput (Page inputPage , InferenceAction .Response inferenceResponse ) {
126- if (inferenceResponse .getResults () instanceof RankedDocsResults rankedDocsResults ) {
127- return buildOutput (inputPage , rankedDocsResults );
128-
129- }
130-
131- throw new IllegalStateException (
132- "Inference result has wrong type. Got ["
133- + inferenceResponse .getResults ().getClass ()
134- + "] while expecting ["
135- + RankedDocsResults .class
136- + "]"
137- );
138- }
139-
140- private Page buildOutput (Page inputPage , RankedDocsResults rankedDocsResults ) {
141- int blockCount = Integer .max (inputPage .getBlockCount (), scoreChannel + 1 );
142- Block [] blocks = new Block [blockCount ];
127+ private InferenceAction .Request buildInferenceRequest (Page inputPage ) {
128+ try (BytesRefBlock encodedRowsBlock = (BytesRefBlock ) rowEncoder .eval (inputPage )) {
129+ assert (encodedRowsBlock .getPositionCount () == inputPage .getPositionCount ());
130+ String [] inputs = new String [inputPage .getPositionCount ()];
131+ BytesRef buffer = new BytesRef ();
143132
144- try {
145- for (int b = 0 ; b < blockCount ; b ++) {
146- if (b == scoreChannel ) {
147- blocks [b ] = buildScoreBlock (inputPage , rankedDocsResults );
133+ for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
134+ if (encodedRowsBlock .isNull (pos )) {
135+ inputs [pos ] = "" ;
148136 } else {
149- blocks [ b ] = inputPage . getBlock ( b );
150- blocks [ b ]. incRef ( );
137+ buffer = encodedRowsBlock . getBytesRef ( encodedRowsBlock . getFirstValueIndex ( pos ), buffer );
138+ inputs [ pos ] = BytesRefs . toString ( buffer );
151139 }
152140 }
153- return new Page (blocks );
154- } catch (Exception e ) {
155- Releasables .closeExpectNoException (blocks );
156- throw (e );
141+
142+ return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
157143 }
158144 }
159145
160- private Block buildScoreBlock (Page inputPage , RankedDocsResults rankedDocsResults ) {
161- Double [] sortedRankedDocsScores = new Double [inputPage .getPositionCount ()];
146+ public static final class OngoingRerank {
147+ final Page inputPage ;
148+ final Double [] rankedScores ;
149+
150+ OngoingRerank (Page inputPage , InferenceAction .Response resp ) {
151+ if (resp .getResults () instanceof RankedDocsResults == false ) {
152+ releasePageOnAnyThread (inputPage );
153+ throw new IllegalStateException (
154+ "Inference result has wrong type. Got ["
155+ + resp .getResults ().getClass ()
156+ + "] while expecting ["
157+ + RankedDocsResults .class
158+ + "]"
159+ );
162160
163- try (DoubleBlock .Builder scoreBlockFactory = blockFactory .newDoubleBlockBuilder (inputPage .getPositionCount ())) {
161+ }
162+ final var results = (RankedDocsResults ) resp .getResults ();
163+ this .inputPage = inputPage ;
164+ this .rankedScores = extractRankedScores (inputPage .getPositionCount (), results );
165+ }
166+
167+ private static Double [] extractRankedScores (int positionCount , RankedDocsResults rankedDocsResults ) {
168+ Double [] sortedRankedDocsScores = new Double [positionCount ];
164169 for (RankedDocsResults .RankedDoc rankedDoc : rankedDocsResults .getRankedDocs ()) {
165170 sortedRankedDocsScores [rankedDoc .index ()] = (double ) rankedDoc .relevanceScore ();
166171 }
172+ return sortedRankedDocsScores ;
173+ }
167174
168- for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
169- if (sortedRankedDocsScores [pos ] != null ) {
170- scoreBlockFactory .appendDouble (sortedRankedDocsScores [pos ]);
171- } else {
172- scoreBlockFactory .appendNull ();
175+ Page buildOutput (BlockFactory blockFactory , int scoreChannel ) {
176+ int blockCount = Integer .max (inputPage .getBlockCount (), scoreChannel + 1 );
177+ Block [] blocks = new Block [blockCount ];
178+ Page outputPage = null ;
179+ try (Releasable ignored = inputPage ::releaseBlocks ) {
180+ for (int b = 0 ; b < blockCount ; b ++) {
181+ if (b == scoreChannel ) {
182+ blocks [b ] = buildScoreBlock (blockFactory );
183+ } else {
184+ blocks [b ] = inputPage .getBlock (b );
185+ blocks [b ].incRef ();
186+ }
187+ }
188+ outputPage = new Page (blocks );
189+ return outputPage ;
190+ } finally {
191+ if (outputPage == null ) {
192+ Releasables .closeExpectNoException (blocks );
173193 }
174194 }
175-
176- return scoreBlockFactory .build ();
177195 }
178- }
179-
180- private InferenceAction .Request buildInferenceRequest (Page inputPage ) {
181- try (BytesRefBlock encodedRowsBlock = (BytesRefBlock ) rowEncoder .eval (inputPage )) {
182- assert (encodedRowsBlock .getPositionCount () == inputPage .getPositionCount ());
183- String [] inputs = new String [inputPage .getPositionCount ()];
184- BytesRef buffer = new BytesRef ();
185196
186- for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
187- if (encodedRowsBlock .isNull (pos )) {
188- inputs [pos ] = "" ;
189- } else {
190- buffer = encodedRowsBlock .getBytesRef (encodedRowsBlock .getFirstValueIndex (pos ), buffer );
191- inputs [pos ] = BytesRefs .toString (buffer );
197+ private Block buildScoreBlock (BlockFactory blockFactory ) {
198+ try (DoubleBlock .Builder scoreBlockFactory = blockFactory .newDoubleBlockBuilder (rankedScores .length )) {
199+ for (Double rankedScore : rankedScores ) {
200+ if (rankedScore != null ) {
201+ scoreBlockFactory .appendDouble (rankedScore );
202+ } else {
203+ scoreBlockFactory .appendNull ();
204+ }
192205 }
206+ return scoreBlockFactory .build ();
193207 }
194-
195- return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
196208 }
197209 }
198210}
0 commit comments