88package org .elasticsearch .xpack .esql .inference ;
99
1010import org .elasticsearch .action .ActionListener ;
11+ import org .elasticsearch .compute .data .BlockFactory ;
1112import org .elasticsearch .compute .data .Page ;
1213import org .elasticsearch .compute .operator .AsyncOperator ;
1314import org .elasticsearch .compute .operator .DriverContext ;
14- import org .elasticsearch .core .CheckedConsumer ;
15- import org .elasticsearch .core .Releasable ;
16- import org .elasticsearch .core .Releasables ;
1715import org .elasticsearch .inference .InferenceServiceResults ;
18- import org .elasticsearch .xpack . core . inference . action . InferenceAction ;
19-
20- import java . util . Iterator ;
21-
22- import static org .elasticsearch .common . logging . LoggerMessageFormat . format ;
16+ import org .elasticsearch .threadpool . ThreadPool ;
17+ import org . elasticsearch . xpack . esql . inference . bulk . BulkInferenceExecutionConfig ;
18+ import org . elasticsearch . xpack . esql . inference . bulk . BulkInferenceExecutor ;
19+ import org . elasticsearch . xpack . esql . inference . bulk . BulkInferenceOutputBuilder ;
20+ import org .elasticsearch .xpack . esql . inference . bulk . BulkInferenceRequestIterator ;
2321
2422public abstract class InferenceOperator <InferenceResult extends InferenceServiceResults > extends AsyncOperator <Page > {
2523
2624 // Move to a setting.
2725 private static final int MAX_INFERENCE_WORKER = 10 ;
28- private final InferenceRunner inferenceRunner ;
2926 private final String inferenceId ;
27+ private final BlockFactory blockFactory ;
3028
31- public InferenceOperator (DriverContext driverContext , InferenceRunner inferenceRunner , String inferenceId ) {
32- super (driverContext , inferenceRunner .threadContext (), MAX_INFERENCE_WORKER );
33- this .inferenceRunner = inferenceRunner ;
29+ private final BulkInferenceExecutor <InferenceResult , Page > bulkInferenceExecutor ;
30+
31+ @ SuppressWarnings ("this-escape" )
32+ public InferenceOperator (DriverContext driverContext , InferenceRunner inferenceRunner , ThreadPool threadPool , String inferenceId ) {
33+ super (driverContext , threadPool .getThreadContext (), MAX_INFERENCE_WORKER );
34+ this .blockFactory = driverContext .blockFactory ();
35+ this .bulkInferenceExecutor = new BulkInferenceExecutor <>(inferenceRunner , threadPool , bulkExecutionConfig ());
3436 this .inferenceId = inferenceId ;
3537 }
3638
39+ protected BlockFactory blockFactory () {
40+ return blockFactory ;
41+ }
42+
3743 protected String inferenceId () {
3844 return inferenceId ;
3945 }
@@ -50,52 +56,14 @@ public Page getOutput() {
5056
5157 @ Override
5258 protected void performAsync (Page input , ActionListener <Page > listener ) {
53- final RequestIterator requests = requests (input );
54- final OutputBuilder <InferenceResult > outputBuilder = outputBuilder (input );
55-
56- new BulkInferenceOperation (requests , outputBuilder ).execute (
57- inferenceExecutionContext (),
58- listener .delegateFailureIgnoreResponseAndWrap (l -> {
59- l .onResponse (outputBuilder .buildOutput ());
60- Releasables .closeExpectNoException (requests , outputBuilder );
61- })
62- );
59+ bulkInferenceExecutor .execute (requests (input ), outputBuilder (input ), listener );
6360 }
6461
65- protected InferenceExecutionContext inferenceExecutionContext () {
66- return inferenceRunner . executionContextBuilder (). build () ;
62+ protected BulkInferenceExecutionConfig bulkExecutionConfig () {
63+ return BulkInferenceExecutionConfig . DEFAULT ;
6764 }
6865
69- protected abstract RequestIterator requests (Page input );
70-
71- protected abstract OutputBuilder <InferenceResult > outputBuilder (Page input );
72-
73- public abstract static class OutputBuilder <InferenceResults extends InferenceServiceResults >
74- implements
75- CheckedConsumer <InferenceAction .Response , Exception >,
76- Releasable {
77- protected abstract Class <InferenceResults > inferenceResultsClass ();
78-
79- public abstract Page buildOutput ();
80-
81- public abstract void onInferenceResults (InferenceResults results );
82-
83- @ Override
84- public void accept (InferenceAction .Response response ) throws Exception {
85- InferenceServiceResults results = response .getResults ();
86- if (inferenceResultsClass ().isInstance (response .getResults ()) == false ) {
87- throw new IllegalStateException (
88- format (
89- "Inference result has wrong type. Got [{}] while expecting [{}]" ,
90- results .getClass ().getName (),
91- inferenceResultsClass ().getName ()
92- )
93- );
94- }
95-
96- onInferenceResults (inferenceResultsClass ().cast (results ));
97- }
98- }
66+ protected abstract BulkInferenceRequestIterator requests (Page input );
9967
100- public interface RequestIterator extends Iterator < InferenceAction . Request >, Releasable {}
68+ protected abstract BulkInferenceOutputBuilder < InferenceResult , Page > outputBuilder ( Page input );
10169}
0 commit comments