1717import org .elasticsearch .compute .data .Page ;
1818import org .elasticsearch .compute .operator .AsyncOperator ;
1919import org .elasticsearch .compute .operator .DriverContext ;
20- import org .elasticsearch .compute .operator .EvalOperator ;
20+ import org .elasticsearch .compute .operator .EvalOperator . ExpressionEvaluator ;
2121import org .elasticsearch .compute .operator .Operator ;
2222import org .elasticsearch .core .Releasables ;
2323import org .elasticsearch .inference .TaskType ;
2727import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
2828
2929import java .io .IOException ;
30+ import java .io .UncheckedIOException ;
3031import java .util .List ;
3132import java .util .Map ;
3233
@@ -39,7 +40,7 @@ public record Factory(
3940 InferenceService inferenceService ,
4041 String inferenceId ,
4142 String queryText ,
42- Map <String , EvalOperator . ExpressionEvaluator .Factory > rerankFieldsEvaluatorFactories ,
43+ Map <String , ExpressionEvaluator .Factory > fieldsEvaluatorFactories ,
4344 int scoreChannel
4445 ) implements OperatorFactory {
4546
@@ -50,7 +51,7 @@ public String describe() {
5051 + " query="
5152 + queryText
5253 + " rerank_fields="
53- + rerankFieldsEvaluatorFactories .keySet ()
54+ + fieldsEvaluatorFactories .keySet ()
5455 + " score_channel="
5556 + scoreChannel
5657 + "]" ;
@@ -63,67 +64,73 @@ public Operator get(DriverContext driverContext) {
6364 inferenceService ,
6465 inferenceId ,
6566 queryText ,
66- rerankFieldsEvaluatorFactories .keySet ().toArray (new String [0 ]),
67- rerankFieldsEvaluatorFactories .values ()
68- .stream ()
69- .map (factory -> factory .get (driverContext ))
70- .toArray (EvalOperator .ExpressionEvaluator []::new ),
67+ fieldNames (),
68+ fieldsEvaluators (driverContext ),
7169 scoreChannel
7270 );
7371 }
72+
73+ private String [] fieldNames () {
74+ return fieldsEvaluatorFactories .keySet ().toArray (String []::new );
75+ }
76+
77+ private ExpressionEvaluator [] fieldsEvaluators (DriverContext context ) {
78+ return fieldsEvaluatorFactories .values ().stream ().map (factory -> factory .get (context )).toArray (ExpressionEvaluator []::new );
79+ }
7480 }
7581
7682 private final InferenceService inferenceService ;
7783 private final BlockFactory blockFactory ;
7884 private final String inferenceId ;
7985 private final String queryText ;
80- private final String [] rerankFieldNames ;
81- private final EvalOperator . ExpressionEvaluator [] rerankFieldsEvaluators ;
86+ private final String [] fieldNames ;
87+ private final ExpressionEvaluator [] fieldsEvaluators ;
8288 private final int scoreChannel ;
8389
8490 public RerankOperator (
8591 DriverContext driverContext ,
8692 InferenceService inferenceService ,
8793 String inferenceId ,
8894 String queryText ,
89- String [] rerankFieldNames ,
90- EvalOperator . ExpressionEvaluator [] rerankFieldsEvaluators ,
95+ String [] fieldNames ,
96+ ExpressionEvaluator [] fieldsEvaluators ,
9197 int scoreChannel
9298 ) {
9399 super (driverContext , inferenceService .getThreadContext (), MAX_INFERENCE_WORKER );
94100
95101 assert inferenceService .getThreadContext () != null ;
102+ assert fieldNames .length == fieldsEvaluators .length ;
96103
97104 this .blockFactory = driverContext .blockFactory ();
98105 this .inferenceService = inferenceService ;
99106 this .inferenceId = inferenceId ;
100107 this .queryText = queryText ;
101- this .rerankFieldNames = rerankFieldNames ;
102- this .rerankFieldsEvaluators = rerankFieldsEvaluators ;
108+ this .fieldNames = fieldNames ;
109+ this .fieldsEvaluators = fieldsEvaluators ;
103110 this .scoreChannel = scoreChannel ;
104111 }
105112
106113 @ Override
107114 protected void performAsync (Page inputPage , ActionListener <Page > listener ) {
108115 // Ensure input page blocks are released when the listener is called.
109- final ActionListener <Page > outputListener = ActionListener .runAfter (listener , inputPage :: releaseBlocks );
116+ final ActionListener <Page > outputListener = ActionListener .runAfter (listener , () -> { inputPage . releaseBlocks (); } );
110117
111- final ActionListener < InferenceAction . Response > inferenceResonseListener = ActionListener . wrap (
112- inferenceResponse -> buildOutput ( inputPage , inferenceResponse , outputListener ),
113- outputListener :: onFailure
114- );
115-
116- final ActionListener < InferenceAction . Request > buildInferenceRequestListener = ActionListener . wrap (
117- ( inferenceRequest ) -> inferenceService . doInference ( inferenceRequest , inferenceResonseListener ),
118- outputListener :: onFailure
119- );
120-
121- buildInferenceRequest ( inputPage , buildInferenceRequestListener );
118+ try {
119+ inferenceService . doInference (
120+ buildInferenceRequest ( inputPage ),
121+ ActionListener . wrap (
122+ inferenceResponse -> outputListener . onResponse ( buildOutput ( inputPage , inferenceResponse )),
123+ outputListener :: onFailure
124+ )
125+ );
126+ } catch ( Exception e ) {
127+ outputListener . onFailure ( e );
128+ }
122129 }
123130
124131 @ Override
125132 protected void doClose () {
126- Releasables .closeExpectNoException (this .rerankFieldsEvaluators );
133+ Releasables .closeExpectNoException (this .fieldsEvaluators );
127134 }
128135
129136 @ Override
@@ -143,30 +150,28 @@ public String toString() {
143150 + " query="
144151 + queryText
145152 + " rerank_fields="
146- + List .of (rerankFieldNames )
153+ + List .of (fieldNames )
147154 + " score_channel="
148155 + scoreChannel
149156 + "]" ;
150157 }
151158
152- private void buildOutput (Page inputPage , InferenceAction .Response inferenceResponse , ActionListener < Page > listener ) {
159+ private Page buildOutput (Page inputPage , InferenceAction .Response inferenceResponse ) {
153160 if (inferenceResponse .getResults () instanceof RankedDocsResults rankedDocsResults ) {
154- buildOutput (inputPage , rankedDocsResults , listener );
155- return ;
161+ return buildOutput (inputPage , rankedDocsResults );
162+
156163 }
157164
158- listener .onFailure (
159- new IllegalStateException (
160- "Inference result has wrong type. Got ["
161- + inferenceResponse .getResults ().getClass ()
162- + "] while expecting ["
163- + RankedDocsResults .class
164- + "]"
165- )
165+ throw new IllegalStateException (
166+ "Inference result has wrong type. Got ["
167+ + inferenceResponse .getResults ().getClass ()
168+ + "] while expecting ["
169+ + RankedDocsResults .class
170+ + "]"
166171 );
167172 }
168173
169- private void buildOutput (Page inputPage , RankedDocsResults rankedDocsResults , ActionListener < Page > listener ) {
174+ private Page buildOutput (Page inputPage , RankedDocsResults rankedDocsResults ) {
170175 int blockCount = inputPage .getBlockCount ();
171176 Block [] blocks = new Block [blockCount ];
172177
@@ -179,9 +184,10 @@ private void buildOutput(Page inputPage, RankedDocsResults rankedDocsResults, Ac
179184 blocks [b ].incRef ();
180185 }
181186 }
182- listener . onResponse ( new Page (blocks ) );
187+ return new Page (blocks );
183188 } catch (Exception e ) {
184189 Releasables .closeExpectNoException (blocks );
190+ throw (e );
185191 }
186192 }
187193
@@ -205,51 +211,34 @@ private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResult
205211 }
206212 }
207213
208- private void buildInferenceRequest (Page inputPage , ActionListener <InferenceAction .Request > listener ) {
209-
210- Block [] inputBlocks = inputBlocks (inputPage );
214+ private InferenceAction .Request buildInferenceRequest (Page inputPage ) {
215+ Block [] inputBlocks = new Block [fieldsEvaluators .length ];
211216
212217 try {
213- String [] inputs = new String [inputPage .getPositionCount ()];
214- if (inputBlocks .length > 0 ) for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
215- inputs [pos ] = toYaml (inputBlocks , pos );
218+ for (int b = 0 ; b < inputBlocks .length ; b ++) {
219+ inputBlocks [b ] = fieldsEvaluators [b ].eval (inputPage );
216220 }
217- listener .onResponse (
218- InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ()
219- );
220- } catch (Exception e ) {
221- listener .onFailure (e );
222- } finally {
223- Releasables .closeExpectNoException (inputBlocks );
224- }
225- }
226221
227- private Block [] inputBlocks (Page inputPage ) {
228- Block [] blocks = new Block [rerankFieldsEvaluators .length ];
229-
230- try {
231- for (int i = 0 ; i < rerankFieldsEvaluators .length ; i ++) {
232- blocks [i ] = rerankFieldsEvaluators [i ].eval (inputPage );
233- }
234-
235- return blocks ;
236- } catch (Exception e ) {
237- Releasables .closeExpectNoException (blocks );
238- throw e ;
239- }
240- }
241-
242- private String toYaml (Block [] inputBlocks , int position ) throws IOException {
243- try (XContentBuilder yamlBuilder = XContentFactory .yamlBuilder ().startObject ()) {
244- for (int i = 0 ; i < inputBlocks .length ; i ++) {
245- String fieldName = rerankFieldNames [i ];
246- Block currentBlock = inputBlocks [i ];
247- if (currentBlock .isNull (position )) {
248- continue ;
222+ String [] inputs = new String [inputPage .getPositionCount ()];
223+ for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
224+ try (XContentBuilder yamlBuilder = XContentFactory .yamlBuilder ().startObject ()) {
225+ for (int i = 0 ; i < inputBlocks .length ; i ++) {
226+ String fieldName = fieldNames [i ];
227+ Block currentBlock = inputBlocks [i ];
228+ if (currentBlock .isNull (pos )) {
229+ continue ;
230+ }
231+ yamlBuilder .field (fieldName , toYaml (BlockUtils .toJavaObject (currentBlock , pos )));
232+ }
233+ inputs [pos ] = Strings .toString (yamlBuilder .endObject ());
234+ } catch (IOException e ) {
235+ throw new UncheckedIOException (e );
249236 }
250- yamlBuilder .field (fieldName , toYaml (BlockUtils .toJavaObject (currentBlock , position )));
251237 }
252- return Strings .toString (yamlBuilder .endObject ());
238+
239+ return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
240+ } finally {
241+ Releasables .closeExpectNoException (inputBlocks );
253242 }
254243 }
255244
@@ -262,7 +251,7 @@ private Object toYaml(Object value) {
262251 };
263252 } catch (Error | Exception e ) {
264253 // Swallow errors caused by invalid byteref.
265- return null ;
254+ return "" ;
266255 }
267256 }
268257}
0 commit comments