1313import org .elasticsearch .compute .operator .AsyncOperator ;
1414import org .elasticsearch .compute .operator .DriverContext ;
1515import org .elasticsearch .core .Releasable ;
16+ import org .elasticsearch .core .Releasables ;
1617import org .elasticsearch .inference .InferenceServiceResults ;
1718import org .elasticsearch .threadpool .ThreadPool ;
1819import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2425
2526import static org .elasticsearch .common .logging .LoggerMessageFormat .format ;
2627
27- public abstract class InferenceOperator extends AsyncOperator <InferenceOperator .OngoingInference > {
/**
 * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}.
 * <p>
 * The {@code InferenceOperator} integrates with the compute framework and supports throttled bulk execution of inference requests. It
 * transforms an input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new
 * {@link Page}.
 * </p>
 */
35+ public abstract class InferenceOperator extends AsyncOperator <InferenceOperator .OngoingInferenceResult > {
2836 private final String inferenceId ;
2937 private final BlockFactory blockFactory ;
3038 private final BulkInferenceExecutor bulkInferenceExecutor ;
3139
/**
 * Creates the operator.
 *
 * @param driverContext The driver context.
 * @param inferenceRunner The runner used to execute inference requests.
 * @param bulkExecutionConfig Configuration for throttled bulk inference execution.
 * @param threadPool The thread pool handed to the bulk inference executor.
 * @param inferenceId The ID of the inference model to use.
 */
public InferenceOperator(
    DriverContext driverContext,
    InferenceRunner inferenceRunner,
    BulkInferenceExecutionConfig bulkExecutionConfig,
    ThreadPool threadPool,
    String inferenceId
) {
    // NOTE(review): the thread context is taken from the runner's pool while the bulk executor
    // receives the threadPool parameter — presumably intentional, but worth confirming.
    super(driverContext, inferenceRunner.threadPool().getThreadContext(), bulkExecutionConfig.workers());
    this.inferenceId = inferenceId;
    this.blockFactory = driverContext.blockFactory();
    this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig);
}
4461
62+ /**
63+ * Returns the {@link BlockFactory} used to create output data blocks.
64+ */
4565 protected BlockFactory blockFactory () {
4666 return blockFactory ;
4767 }
4868
69+ /**
70+ * Returns the inference model ID used for this operator.
71+ */
4972 protected String inferenceId () {
5073 return inferenceId ;
5174 }
5275
76+ /**
77+ * Initiates asynchronous inferences for the given input page.
78+ */
5379 @ Override
54- protected void releaseFetchedOnAnyThread (OngoingInference result ) {
55- releasePageOnAnyThread (result .inputPage );
56- }
57-
58- @ Override
59- protected void performAsync (Page input , ActionListener <OngoingInference > listener ) {
80+ protected void performAsync (Page input , ActionListener <OngoingInferenceResult > listener ) {
6081 try {
6182 BulkInferenceRequestIterator requests = requests (input );
6283 listener = ActionListener .releaseBefore (requests , listener );
63- bulkInferenceExecutor .execute (requests , listener .map (responses -> new OngoingInference (input , responses )));
84+ bulkInferenceExecutor .execute (requests , listener .map (responses -> new OngoingInferenceResult (input , responses )));
6485 } catch (Exception e ) {
6586 listener .onFailure (e );
6687 }
6788 }
6889
90+ /**
91+ * Releases resources associated with an ongoing inference.
92+ */
93+ @ Override
94+ protected void releaseFetchedOnAnyThread (OngoingInferenceResult ongoingInferenceResult ) {
95+ Releasables .close (ongoingInferenceResult );
96+ }
97+
98+ /**
99+ * Returns the next available output page constructed from completed inference results.
100+ */
69101 @ Override
70102 public Page getOutput () {
71- OngoingInference ongoingInference = fetchFromBuffer ();
72- if (ongoingInference == null ) {
103+ OngoingInferenceResult ongoingInferenceResult = fetchFromBuffer ();
104+ if (ongoingInferenceResult == null ) {
73105 return null ;
74106 }
75107
76- try (OutputBuilder outputBuilder = outputBuilder (ongoingInference .inputPage )) {
77- ongoingInference .responses .forEach (outputBuilder ::addInferenceResponse );
108+ try (OutputBuilder outputBuilder = outputBuilder (ongoingInferenceResult .inputPage )) {
109+ assert ongoingInferenceResult .inputPage .getPositionCount () == ongoingInferenceResult .responses .size ();
110+ for (InferenceAction .Response response : ongoingInferenceResult .responses ) {
111+ try {
112+ outputBuilder .addInferenceResponse (response );
113+ } catch (IllegalArgumentException e ) {
114+ throw new IllegalStateException ("Invalid inference response" , e );
115+ }
116+ }
78117 return outputBuilder .buildOutput ();
118+
79119 } finally {
80- releaseFetchedOnAnyThread (ongoingInference );
120+ releaseFetchedOnAnyThread (ongoingInferenceResult );
81121 }
82122 }
83123
/**
 * Converts the given input page into a sequence of inference requests.
 *
 * @param input The input page to process.
 * @return An iterator over the inference requests derived from {@code input}; it is released
 *         by {@link #performAsync} once execution completes.
 */
protected abstract BulkInferenceRequestIterator requests(Page input);

/**
 * Creates a new {@link OutputBuilder} instance used to build the output page.
 *
 * @param input The corresponding input page used to generate the inference requests.
 * @return A builder that accumulates inference responses for {@code input}.
 */
protected abstract OutputBuilder outputBuilder(Page input);
87137
138+ /**
139+ * An interface for accumulating inference responses and constructing a result {@link Page}.
140+ */
88141 public interface OutputBuilder extends Releasable {
142+
/**
 * Adds an inference response to the output.
 * <p>
 * Responses must be added in the same order as the corresponding inference requests were
 * generated; failing to preserve order may produce misaligned output rows.
 * </p>
 *
 * @param inferenceResponse The inference response to include.
 */
void addInferenceResponse(InferenceAction.Response inferenceResponse);
90153
/**
 * Builds the final output page from the accumulated inference responses.
 *
 * @return The constructed output page.
 */
Page buildOutput();
92160
93161 static <IR extends InferenceServiceResults > IR inferenceResults (InferenceAction .Response inferenceResponse , Class <IR > clazz ) {
@@ -102,7 +170,18 @@ static <IR extends InferenceServiceResults> IR inferenceResults(InferenceAction.
102170 }
103171 }
104172
105- public record OngoingInference (Page inputPage , List <InferenceAction .Response > responses ) {
106-
173+ /**
174+ * Represents the result of an ongoing inference operation, including the original input page
175+ * and the list of inference responses.
176+ *
177+ * @param inputPage The input page used to generate inference requests.
178+ * @param responses The inference responses returned by the inference service.
179+ */
180+ public record OngoingInferenceResult (Page inputPage , List <InferenceAction .Response > responses ) implements Releasable {
181+
182+ @ Override
183+ public void close () {
184+ releasePageOnAnyThread (inputPage );
185+ }
107186 }
108187}
0 commit comments