1515import org .elasticsearch .compute .data .BytesRefBlock ;
1616import org .elasticsearch .compute .data .DoubleBlock ;
1717import org .elasticsearch .compute .data .Page ;
18- import org .elasticsearch .compute .operator .AsyncOperator ;
1918import org .elasticsearch .compute .operator .DriverContext ;
2019import org .elasticsearch .compute .operator .EvalOperator .ExpressionEvaluator ;
2120import org .elasticsearch .compute .operator .Operator ;
2625
2726import java .util .List ;
2827
29- public class RerankOperator extends AsyncOperator <Page > {
30-
31- // Move to a setting.
32- private static final int MAX_INFERENCE_WORKER = 10 ;
33-
28+ public class RerankOperator extends InferenceOperator <Page > {
3429 public record Factory (
3530 InferenceRunner inferenceRunner ,
3631 String inferenceId ,
@@ -57,9 +52,7 @@ public Operator get(DriverContext driverContext) {
5752 }
5853 }
5954
60- private final InferenceRunner inferenceRunner ;
6155 private final BlockFactory blockFactory ;
62- private final String inferenceId ;
6356 private final String queryText ;
6457 private final ExpressionEvaluator rowEncoder ;
6558 private final int scoreChannel ;
@@ -72,13 +65,9 @@ public RerankOperator(
7265 ExpressionEvaluator rowEncoder ,
7366 int scoreChannel
7467 ) {
75- super (driverContext , inferenceRunner .getThreadContext (), MAX_INFERENCE_WORKER );
76-
77- assert inferenceRunner .getThreadContext () != null ;
68+ super (driverContext , inferenceRunner .getThreadContext (), inferenceRunner , inferenceId );
7869
7970 this .blockFactory = driverContext .blockFactory ();
80- this .inferenceRunner = inferenceRunner ;
81- this .inferenceId = inferenceId ;
8271 this .queryText = queryText ;
8372 this .rowEncoder = rowEncoder ;
8473 this .scoreChannel = scoreChannel ;
@@ -90,7 +79,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
9079 final ActionListener <Page > outputListener = ActionListener .runAfter (listener , () -> { releasePageOnAnyThread (inputPage ); });
9180
9281 try {
93- inferenceRunner . doInference (
82+ doInference (
9483 buildInferenceRequest (inputPage ),
9584 ActionListener .wrap (
9685 inferenceResponse -> outputListener .onResponse (buildOutput (inputPage , inferenceResponse )),
@@ -119,7 +108,7 @@ public Page getOutput() {
119108
120109 @ Override
121110 public String toString () {
122- return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]" ;
111+ return "RerankOperator[inference_id=[" + inferenceId () + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]" ;
123112 }
124113
125114 private Page buildOutput (Page inputPage , InferenceAction .Response inferenceResponse ) {
@@ -192,7 +181,7 @@ private InferenceAction.Request buildInferenceRequest(Page inputPage) {
192181 }
193182 }
194183
195- return InferenceAction .Request .builder (inferenceId , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
184+ return InferenceAction .Request .builder (inferenceId () , TaskType .RERANK ).setInput (List .of (inputs )).setQuery (queryText ).build ();
196185 }
197186 }
198187}
0 commit comments