77
88package org .elasticsearch .xpack .esql .plan .physical .inference ;
99
10+ import org .apache .lucene .util .BytesRef ;
1011import org .elasticsearch .action .ActionListener ;
12+ import org .elasticsearch .common .Strings ;
13+ import org .elasticsearch .compute .data .Block ;
1114import org .elasticsearch .compute .data .BlockFactory ;
15+ import org .elasticsearch .compute .data .BlockUtils ;
16+ import org .elasticsearch .compute .data .DoubleBlock ;
17+ import org .elasticsearch .compute .data .ElementType ;
1218import org .elasticsearch .compute .data .Page ;
1319import org .elasticsearch .compute .operator .AsyncOperator ;
1420import org .elasticsearch .compute .operator .DriverContext ;
1521import org .elasticsearch .compute .operator .EvalOperator ;
1622import org .elasticsearch .compute .operator .Operator ;
23+ import org .elasticsearch .inference .TaskType ;
1724import org .elasticsearch .logging .LogManager ;
1825import org .elasticsearch .logging .Logger ;
26+ import org .elasticsearch .xcontent .XContentBuilder ;
27+ import org .elasticsearch .xcontent .XContentFactory ;
28+ import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
29+ import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
1930import org .elasticsearch .xpack .esql .inference .InferenceService ;
2031
32+ import java .io .IOException ;
2133import java .util .HashMap ;
34+ import java .util .List ;
2235import java .util .Map ;
2336
2437public class RerankOperator extends AsyncOperator <Page > {
@@ -101,7 +114,17 @@ public RerankOperator(
101114
102115 @ Override
103116 protected void performAsync (Page inputPage , ActionListener <Page > listener ) {
104- listener .onResponse (inputPage );
117+ try {
118+ inferenceService .doInference (
119+ buildInferenceRequest (inputPage ),
120+ ActionListener .wrap (
121+ (inferenceResponse ) -> listener .onResponse (buildOutput (inputPage , inferenceResponse )),
122+ listener ::onFailure
123+ )
124+ );
125+ } catch (IOException e ) {
126+ listener .onFailure (e );
127+ }
105128 }
106129
107130 @ Override
@@ -118,4 +141,78 @@ protected void releaseFetchedOnAnyThread(Page page) {
118141 public Page getOutput () {
119142 return fetchFromBuffer ();
120143 }
144+
145+ private Page buildOutput (Page inputPage , InferenceAction .Response inferenceResponse ) {
146+ int blockCount = inputPage .getBlockCount ();
147+ Block .Builder [] blocksBuilders = new Block .Builder [blockCount ];
148+
149+ for (int b = 0 ; b < blockCount ; b ++) {
150+ if (b == scoreChannel ) {
151+ blocksBuilders [b ] = ElementType .DOUBLE .newBlockBuilder (inputPage .getPositionCount (), blockFactory );
152+ } else {
153+ blocksBuilders [b ] = inputPage .getBlock (b ).elementType ().newBlockBuilder (inputPage .getPositionCount (), blockFactory );
154+ }
155+ }
156+
157+ if (inferenceResponse .getResults () instanceof RankedDocsResults rankedDocsResults ) {
158+ for (var rankedDoc : rankedDocsResults .getRankedDocs ()) {
159+ for (int b = 0 ; b < blockCount ; b ++) {
160+ if (b == scoreChannel ) {
161+ if (blocksBuilders [b ] instanceof DoubleBlock .Builder scoreBlockBuilder ) {
162+ scoreBlockBuilder .beginPositionEntry ().appendDouble (rankedDoc .relevanceScore ()).endPositionEntry ();
163+ }
164+ } else {
165+ blocksBuilders [b ].copyFrom (inputPage .getBlock (b ), rankedDoc .index (), rankedDoc .index () + 1 );
166+ }
167+ }
168+ }
169+
170+ return new Page (Block .Builder .buildAll (blocksBuilders ));
171+ }
172+
173+ throw new IllegalStateException (
174+ "Inference result has wrong type. Got ["
175+ + inferenceResponse .getResults ().getClass ()
176+ + "] while expecting ["
177+ + RankedDocsResults .class
178+ + "]"
179+ );
180+ }
181+
182+ private InferenceAction .Request buildInferenceRequest (Page inputPage ) throws IOException {
183+ String [] inputs = new String [inputPage .getPositionCount ()];
184+ Map <String , Block > inputBlocks = new HashMap <>();
185+
186+ for (var entry : rerankFieldsEvaluator .entrySet ()) {
187+ inputBlocks .put (entry .getKey (), entry .getValue ().eval (inputPage ));
188+ }
189+
190+ for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
191+ inputs [pos ] = toYaml (inputBlocks , pos );
192+ }
193+
194+ return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
195+ }
196+
197+ private String toYaml (Map <String , Block > inputBlocks , int position ) throws IOException {
198+ try (XContentBuilder yamlBuilder = XContentFactory .yamlBuilder ().startObject ()) {
199+ for (var blockEntry : inputBlocks .entrySet ()) {
200+ String fieldName = blockEntry .getKey ();
201+ Block currentBlock = blockEntry .getValue ();
202+ if (currentBlock .isNull (position )) {
203+ continue ;
204+ }
205+ yamlBuilder .field (fieldName , toYaml (BlockUtils .toJavaObject (currentBlock , position )));
206+ }
207+ return Strings .toString (yamlBuilder .endObject ());
208+ }
209+ }
210+
211+ private Object toYaml (Object value ) {
212+ return switch (value ) {
213+ case BytesRef b -> b .utf8ToString ();
214+ case List <?> l -> l .stream ().map (this ::toYaml ).toList ();
215+ default -> value ;
216+ };
217+ }
121218}
0 commit comments