99
1010import org .apache .lucene .util .BytesRef ;
1111import org .elasticsearch .action .ActionListener ;
12+ import org .elasticsearch .common .Strings ;
1213import org .elasticsearch .compute .data .Block ;
1314import org .elasticsearch .compute .data .BlockFactory ;
14- import org .elasticsearch .compute .data .BytesRefBlock ;
15+ import org .elasticsearch .compute .data .BlockUtils ;
1516import org .elasticsearch .compute .data .DoubleBlock ;
1617import org .elasticsearch .compute .data .ElementType ;
1718import org .elasticsearch .compute .data .Page ;
2122import org .elasticsearch .inference .TaskType ;
2223import org .elasticsearch .logging .LogManager ;
2324import org .elasticsearch .logging .Logger ;
25+ import org .elasticsearch .xcontent .XContentBuilder ;
26+ import org .elasticsearch .xcontent .XContentFactory ;
2427import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2528import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
2629
30+ import java .io .IOException ;
31+ import java .util .HashMap ;
2732import java .util .List ;
33+ import java .util .Map ;
2834
2935public class RerankOperator extends AsyncOperator <Page > {
3036
@@ -34,23 +40,39 @@ public record Factory(
3440 InferenceService inferenceService ,
3541 String inferenceId ,
3642 String queryText ,
37- EvalOperator .ExpressionEvaluator .Factory inputEvaluatorFactory ,
43+ Map < String , EvalOperator .ExpressionEvaluator .Factory > rerankFieldsEvaluatorSuppliers ,
3844 int scoreChannel ,
3945 int maxOutstandingRequests
4046 ) implements OperatorFactory {
4147 @ Override
4248 public RerankOperator get (DriverContext driverContext ) {
49+
50+
4351 return new RerankOperator (
4452 inferenceService ,
4553 inferenceId ,
4654 queryText ,
47- inputEvaluatorFactory . get ( driverContext ),
55+ buildRerankFieldEvaluator ( rerankFieldsEvaluatorSuppliers , driverContext ),
4856 scoreChannel ,
4957 driverContext ,
5058 maxOutstandingRequests
5159 );
5260 }
5361
62+
63+ private Map <String , EvalOperator .ExpressionEvaluator > buildRerankFieldEvaluator (
64+ Map <String , EvalOperator .ExpressionEvaluator .Factory > rerankFieldsEvaluatorSuppliers ,
65+ DriverContext driverContext
66+ ) {
67+ Map <String , EvalOperator .ExpressionEvaluator > rerankFieldsEvaluators = new HashMap <>();
68+
69+ for (var entry : rerankFieldsEvaluatorSuppliers .entrySet ()) {
70+ rerankFieldsEvaluators .put (entry .getKey (), entry .getValue ().get (driverContext ));
71+ }
72+
73+ return rerankFieldsEvaluators ;
74+ }
75+
5476 @ Override
5577 public String describe () {
5678 return "RerankOperator[maxOutstandingRequests = " + maxOutstandingRequests + "]" ;
@@ -61,14 +83,14 @@ public String describe() {
6183 private final BlockFactory blockFactory ;
6284 private final String inferenceId ;
6385 private final String queryText ;
64- private final EvalOperator .ExpressionEvaluator inputEvaluator ;
86+ private final Map < String , EvalOperator .ExpressionEvaluator > rerankFieldsEvaluator ;
6587 private final int scoreChannel ;
6688
6789 public RerankOperator (
6890 InferenceService inferenceService ,
6991 String inferenceId ,
7092 String queryText ,
71- EvalOperator .ExpressionEvaluator inputEvaluator ,
93+ Map < String , EvalOperator .ExpressionEvaluator > rerankFieldsEvaluator ,
7294 int scoreChannel ,
7395 DriverContext driverContext ,
7496 int maxOutstandingRequests
@@ -78,7 +100,7 @@ public RerankOperator(
78100 this .blockFactory = driverContext .blockFactory ();
79101 this .inferenceId = inferenceId ;
80102 this .queryText = queryText ;
81- this .inputEvaluator = inputEvaluator ;
103+ this .rerankFieldsEvaluator = rerankFieldsEvaluator ;
82104 this .scoreChannel = scoreChannel ;
83105 }
84106
@@ -100,9 +122,13 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
100122 queryText ,
101123 inputPage .getPositionCount ()
102124 );
103- inferenceService .infer (buildInferenceRequest (inputPage ), ActionListener .wrap ((inferenceResponse ) -> {
104- listener .onResponse (buildOutput (inputPage , inferenceResponse ));
105- }, listener ::onFailure ));
125+ try {
126+ inferenceService .infer (buildInferenceRequest (inputPage ), ActionListener .wrap ((inferenceResponse ) -> {
127+ listener .onResponse (buildOutput (inputPage , inferenceResponse ));
128+ }, listener ::onFailure ));
129+ } catch (IOException e ) {
130+ listener .onFailure (e );
131+ }
106132 }
107133
108134 @ Override
@@ -151,19 +177,40 @@ private Page buildOutput(Page inputPage, InferenceAction.Response inferenceRespo
151177 );
152178 }
153179
154- private InferenceAction .Request buildInferenceRequest (Page inputPage ) {
155- BytesRef scratch = new BytesRef ();
180+ private InferenceAction .Request buildInferenceRequest (Page inputPage ) throws IOException {
156181 String [] inputs = new String [inputPage .getPositionCount ()];
157- BytesRefBlock inputBlock = (BytesRefBlock ) inputEvaluator .eval (inputPage );
182+ Map <String , Block > inputBlocks = new HashMap <>();
183+
184+
185+ for (var entry :rerankFieldsEvaluator .entrySet ()) {
186+ inputBlocks .put (entry .getKey (), entry .getValue ().eval (inputPage ));
187+ };
158188
159189 for (int pos = 0 ; pos < inputPage .getPositionCount (); pos ++) {
160- if (inputBlock .isNull (pos ) || inputBlock .getValueCount (pos ) > 1 ) {
161- inputs [pos ] = "" ;
162- } else {
163- inputs [pos ] = inputBlock .getBytesRef (pos , scratch ).utf8ToString ();
190+ try (XContentBuilder yamlBuilder = XContentFactory .yamlBuilder ().startObject ()) {
191+ for (var blockEntry : inputBlocks .entrySet ()) {
192+ String fieldName = blockEntry .getKey ();
193+ Block currentBlock = blockEntry .getValue ();
194+ if (currentBlock .isNull (pos )) {
195+ continue ;
196+ }
197+ Object value = BlockUtils .toJavaObject (currentBlock , pos );
198+ yamlBuilder .field (fieldName , toYamlValue (value ));
199+ }
200+ yamlBuilder .endObject ();
201+ inputs [pos ] = Strings .toString (yamlBuilder );
164202 }
165203 }
166204
205+
167206 return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
168207 }
208+
209+ private Object toYamlValue (Object value ) {
210+ return switch (value ) {
211+ case BytesRef b -> b .utf8ToString ();
212+ case List <?> l -> l .stream ().map (this ::toYamlValue );
213+ default -> value ;
214+ };
215+ }
169216}
0 commit comments