99
1010import org .apache .lucene .util .BytesRef ;
1111import org .elasticsearch .action .ActionListener ;
12- import org .elasticsearch .common .Strings ;
1312import org .elasticsearch .compute .data .Block ;
1413import org .elasticsearch .compute .data .BlockFactory ;
15- import org .elasticsearch .compute .data .BlockUtils ;
14+ import org .elasticsearch .compute .data .BytesRefBlock ;
1615import org .elasticsearch .compute .data .DoubleBlock ;
1716import org .elasticsearch .compute .data .Page ;
1817import org .elasticsearch .compute .operator .AsyncOperator ;
2120import org .elasticsearch .compute .operator .Operator ;
2221import org .elasticsearch .core .Releasables ;
2322import org .elasticsearch .inference .TaskType ;
24- import org .elasticsearch .xcontent .XContentBuilder ;
25- import org .elasticsearch .xcontent .XContentFactory ;
2623import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2724import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
2825
29- import java .io .IOException ;
30- import java .io .UncheckedIOException ;
3126import java .util .List ;
3227import java .util .Map ;
3328
29+ import static org .elasticsearch .xpack .esql .inference .XContentRowEncoder .yamlRowEncoderFactory ;
30+
3431public class RerankOperator extends AsyncOperator <Page > {
3532
3633 // Move to a setting.
@@ -46,15 +43,15 @@ public record Factory(
4643
4744 @ Override
4845 public String describe () {
49- return "RerankOperator[inference_id="
46+ return "RerankOperator[inference_id=[ "
5047 + inferenceId
51- + " query="
48+ + "], query=[ "
5249 + queryText
53- + " rerank_fields="
50+ + "], rerank_fields="
5451 + fieldsEvaluatorFactories .keySet ()
55- + " score_channel="
52+ + ", score_channel=[ "
5653 + scoreChannel
57- + "]" ;
54+ + "]] " ;
5855 }
5956
6057 @ Override
@@ -64,56 +61,43 @@ public Operator get(DriverContext driverContext) {
6461 inferenceService ,
6562 inferenceId ,
6663 queryText ,
67- fieldNames (),
68- fieldsEvaluators (driverContext ),
64+ yamlRowEncoderFactory (fieldsEvaluatorFactories ).get (driverContext ),
6965 scoreChannel
7066 );
7167 }
72-
73- private String [] fieldNames () {
74- return fieldsEvaluatorFactories .keySet ().toArray (String []::new );
75- }
76-
77- private ExpressionEvaluator [] fieldsEvaluators (DriverContext context ) {
78- return fieldsEvaluatorFactories .values ().stream ().map (factory -> factory .get (context )).toArray (ExpressionEvaluator []::new );
79- }
8068 }
8169
8270 private final InferenceService inferenceService ;
8371 private final BlockFactory blockFactory ;
8472 private final String inferenceId ;
8573 private final String queryText ;
86- private final String [] fieldNames ;
87- private final ExpressionEvaluator [] fieldsEvaluators ;
74+ private final ExpressionEvaluator rowEncoder ;
8875 private final int scoreChannel ;
8976
9077 public RerankOperator (
9178 DriverContext driverContext ,
9279 InferenceService inferenceService ,
9380 String inferenceId ,
9481 String queryText ,
95- String [] fieldNames ,
96- ExpressionEvaluator [] fieldsEvaluators ,
82+ ExpressionEvaluator rowEncoder ,
9783 int scoreChannel
9884 ) {
9985 super (driverContext , inferenceService .getThreadContext (), MAX_INFERENCE_WORKER );
10086
10187 assert inferenceService .getThreadContext () != null ;
102- assert fieldNames .length == fieldsEvaluators .length ;
10388
10489 this .blockFactory = driverContext .blockFactory ();
10590 this .inferenceService = inferenceService ;
10691 this .inferenceId = inferenceId ;
10792 this .queryText = queryText ;
108- this .fieldNames = fieldNames ;
109- this .fieldsEvaluators = fieldsEvaluators ;
93+ this .rowEncoder = rowEncoder ;
11094 this .scoreChannel = scoreChannel ;
11195 }
11296
11397 @ Override
11498 protected void performAsync (Page inputPage , ActionListener <Page > listener ) {
11599 // Ensure input page blocks are released when the listener is called.
116- final ActionListener <Page > outputListener = ActionListener .runAfter (listener , () -> { inputPage . releaseBlocks ( ); });
100+ final ActionListener <Page > outputListener = ActionListener .runAfter (listener , () -> { releasePageOnAnyThread ( inputPage ); });
117101
118102 try {
119103 inferenceService .doInference (
@@ -130,7 +114,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
130114
131115 @ Override
132116 protected void doClose () {
133- Releasables .closeExpectNoException (this . fieldsEvaluators );
117+ Releasables .closeExpectNoException (rowEncoder );
134118 }
135119
136120 @ Override
@@ -145,15 +129,15 @@ public Page getOutput() {
145129
146130 @ Override
147131 public String toString () {
148- return "RerankOperator[inference_id="
132+ return "RerankOperator[inference_id=[ "
149133 + inferenceId
150- + " query="
134+ + "], query=[ "
151135 + queryText
152- + " rerank_fields= "
153- + List . of ( fieldNames )
154- + " score_channel="
136+ + "], row_encoder=[ "
137+ + rowEncoder
138+ + "], score_channel=[ "
155139 + scoreChannel
156- + "]" ;
140+ + "]] " ;
157141 }
158142
159143 private Page buildOutput (Page inputPage , InferenceAction .Response inferenceResponse ) {
@@ -212,46 +196,21 @@ private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResult
212196 }
213197
214198 private InferenceAction .Request buildInferenceRequest (Page inputPage ) {
215- Block [] inputBlocks = new Block [fieldsEvaluators .length ];
216-
217- try {
218- for (int b = 0 ; b < inputBlocks .length ; b ++) {
219- inputBlocks [b ] = fieldsEvaluators [b ].eval (inputPage );
220- }
221-
199+ try (BytesRefBlock encodedRowBlock = (BytesRefBlock ) rowEncoder .eval (inputPage )) {
200+ assert (encodedRowBlock .getPositionCount () == inputPage .getPositionCount ());
222201 String [] inputs = new String [inputPage .getPositionCount ()];
202+ BytesRef buffer = new BytesRef ();
203+
223204 for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
224- try (XContentBuilder yamlBuilder = XContentFactory .yamlBuilder ().startObject ()) {
225- for (int i = 0 ; i < inputBlocks .length ; i ++) {
226- String fieldName = fieldNames [i ];
227- Block currentBlock = inputBlocks [i ];
228- if (currentBlock .isNull (pos )) {
229- continue ;
230- }
231- yamlBuilder .field (fieldName , toYaml (BlockUtils .toJavaObject (currentBlock , pos )));
232- }
233- inputs [pos ] = Strings .toString (yamlBuilder .endObject ());
234- } catch (IOException e ) {
235- throw new UncheckedIOException (e );
205+ if (encodedRowBlock .isNull (pos )) {
206+ inputs [pos ] = "" ;
207+ } else {
208+ buffer = encodedRowBlock .getBytesRef (encodedRowBlock .getFirstValueIndex (pos ), buffer );
209+ inputs [pos ] = buffer .utf8ToString ();
236210 }
237211 }
238212
239213 return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
240- } finally {
241- Releasables .closeExpectNoException (inputBlocks );
242- }
243- }
244-
245- private Object toYaml (Object value ) {
246- try {
247- return switch (value ) {
248- case BytesRef b -> b .utf8ToString ();
249- case List <?> l -> l .stream ().map (this ::toYaml ).toList ();
250- default -> value ;
251- };
252- } catch (Error | Exception e ) {
253- // Swallow errors caused by invalid byteref.
254- return "" ;
255214 }
256215 }
257216}
0 commit comments