1919import org .elasticsearch .compute .operator .DriverContext ;
2020import org .elasticsearch .compute .operator .EvalOperator .ExpressionEvaluator ;
2121import org .elasticsearch .compute .operator .Operator ;
22+ import org .elasticsearch .core .Releasable ;
2223import org .elasticsearch .core .Releasables ;
2324import org .elasticsearch .inference .TaskType ;
2425import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2526import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
2627
2728import java .util .List ;
2829
29- public class RerankOperator extends AsyncOperator <Page > {
30+ public class RerankOperator extends AsyncOperator <RerankOperator . InputPageAndRankedScores > {
3031
3132 // Move to a setting.
3233 private static final int MAX_INFERENCE_WORKER = 10 ;
@@ -85,20 +86,30 @@ public RerankOperator(
8586 }
8687
8788 @ Override
88- protected void performAsync (Page inputPage , ActionListener <Page > listener ) {
89+ protected void performAsync (Page inputPage , ActionListener <InputPageAndRankedScores > listener ) {
8990 // Ensure input page blocks are released when the listener is called.
90- final ActionListener <Page > outputListener = ActionListener .runAfter (listener , () -> { releasePageOnAnyThread (inputPage ); });
91-
91+ listener = listener .delegateResponse ((l , e ) -> {
92+ releasePageOnAnyThread (inputPage );
93+ l .onFailure (e );
94+ });
9295 try {
93- inferenceRunner .doInference (
94- buildInferenceRequest (inputPage ),
95- ActionListener .wrap (
96- inferenceResponse -> outputListener .onResponse (buildOutput (inputPage , inferenceResponse )),
97- outputListener ::onFailure
98- )
99- );
96+ inferenceRunner .doInference (buildInferenceRequest (inputPage ), listener .map (resp -> {
97+ if (resp .getResults () instanceof RankedDocsResults == false ) {
98+ releasePageOnAnyThread (inputPage );
99+ throw new IllegalStateException (
100+ "Inference result has wrong type. Got ["
101+ + resp .getResults ().getClass ()
102+ + "] while expecting ["
103+ + RankedDocsResults .class
104+ + "]"
105+ );
106+
107+ }
108+ final var results = (RankedDocsResults ) resp .getResults ();
109+ return new InputPageAndRankedScores (inputPage , extractRankedScores (inputPage .getPositionCount (), results ));
110+ }));
100111 } catch (Exception e ) {
101- outputListener .onFailure (e );
112+ listener .onFailure (e );
102113 }
103114 }
104115
@@ -108,71 +119,63 @@ protected void doClose() {
108119 }
109120
110121 @ Override
111- protected void releaseFetchedOnAnyThread (Page page ) {
112- releasePageOnAnyThread (page );
122+ protected void releaseFetchedOnAnyThread (InputPageAndRankedScores result ) {
123+ releasePageOnAnyThread (result . inputPage () );
113124 }
114125
115126 @ Override
116127 public Page getOutput () {
117- return fetchFromBuffer ();
128+ var fetched = fetchFromBuffer ();
129+ if (fetched == null ) {
130+ return null ;
131+ }
132+ return buildOutput (fetched .inputPage (), fetched .rankedScores ());
118133 }
119134
120135 @ Override
121136 public String toString () {
122137 return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]" ;
123138 }
124139
125- private Page buildOutput (Page inputPage , InferenceAction .Response inferenceResponse ) {
126- if (inferenceResponse .getResults () instanceof RankedDocsResults rankedDocsResults ) {
127- return buildOutput (inputPage , rankedDocsResults );
128-
129- }
130-
131- throw new IllegalStateException (
132- "Inference result has wrong type. Got ["
133- + inferenceResponse .getResults ().getClass ()
134- + "] while expecting ["
135- + RankedDocsResults .class
136- + "]"
137- );
138- }
139-
140- private Page buildOutput (Page inputPage , RankedDocsResults rankedDocsResults ) {
140+ private Page buildOutput (Page inputPage , Double [] rankedScores ) {
141141 int blockCount = Integer .max (inputPage .getBlockCount (), scoreChannel + 1 );
142142 Block [] blocks = new Block [blockCount ];
143-
144- try {
143+ Page outputPage = null ;
144+ try ( Releasable ignored = inputPage :: releaseBlocks ) {
145145 for (int b = 0 ; b < blockCount ; b ++) {
146146 if (b == scoreChannel ) {
147- blocks [b ] = buildScoreBlock (inputPage , rankedDocsResults );
147+ blocks [b ] = buildScoreBlock (rankedScores );
148148 } else {
149149 blocks [b ] = inputPage .getBlock (b );
150150 blocks [b ].incRef ();
151151 }
152152 }
153- return new Page (blocks );
154- } catch (Exception e ) {
155- Releasables .closeExpectNoException (blocks );
156- throw (e );
153+ outputPage = new Page (blocks );
154+ return outputPage ;
155+ } finally {
156+ if (outputPage == null ) {
157+ Releasables .closeExpectNoException (blocks );
158+ }
157159 }
158160 }
159161
160- private Block buildScoreBlock ( Page inputPage , RankedDocsResults rankedDocsResults ) {
161- Double [] sortedRankedDocsScores = new Double [inputPage . getPositionCount () ];
162-
163- try ( DoubleBlock . Builder scoreBlockFactory = blockFactory . newDoubleBlockBuilder ( inputPage . getPositionCount ())) {
164- for ( RankedDocsResults . RankedDoc rankedDoc : rankedDocsResults . getRankedDocs ()) {
165- sortedRankedDocsScores [ rankedDoc . index ()] = ( double ) rankedDoc . relevanceScore () ;
166- }
162+ private Double [] extractRankedScores ( int positionCount , RankedDocsResults rankedDocsResults ) {
163+ Double [] sortedRankedDocsScores = new Double [positionCount ];
164+ for ( RankedDocsResults . RankedDoc rankedDoc : rankedDocsResults . getRankedDocs ()) {
165+ sortedRankedDocsScores [ rankedDoc . index ()] = ( double ) rankedDoc . relevanceScore ();
166+ }
167+ return sortedRankedDocsScores ;
168+ }
167169
168- for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
169- if (sortedRankedDocsScores [pos ] != null ) {
170- scoreBlockFactory .appendDouble (sortedRankedDocsScores [pos ]);
170+ private Block buildScoreBlock (Double [] rankedScores ) {
171+ try (DoubleBlock .Builder scoreBlockFactory = blockFactory .newDoubleBlockBuilder (rankedScores .length )) {
172+ for (Double rankedScore : rankedScores ) {
173+ if (rankedScore != null ) {
174+ scoreBlockFactory .appendDouble (rankedScore );
171175 } else {
172176 scoreBlockFactory .appendNull ();
173177 }
174178 }
175-
176179 return scoreBlockFactory .build ();
177180 }
178181 }
@@ -195,4 +198,8 @@ private InferenceAction.Request buildInferenceRequest(Page inputPage) {
195198 return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
196199 }
197200 }
201+
202+ public record InputPageAndRankedScores (Page inputPage , Double [] rankedScores ) {
203+
204+ }
198205}
0 commit comments