88package org .elasticsearch .xpack .esql .inference .bulk ;
99
1010import org .elasticsearch .action .ActionListener ;
11- import org .elasticsearch .common .util .concurrent .ThrottledTaskRunner ;
11+ import org .elasticsearch .common .util .concurrent .AbstractRunnable ;
12+ import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
1213import org .elasticsearch .threadpool .ThreadPool ;
1314import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
1415import org .elasticsearch .xpack .esql .inference .InferenceRunner ;
1516import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
1617
1718import java .util .ArrayList ;
1819import java .util .List ;
20+ import java .util .concurrent .BlockingQueue ;
1921import java .util .concurrent .ExecutorService ;
20- import java .util .concurrent .RejectedExecutionException ;
21- import java .util .concurrent .TimeoutException ;
22+ import java .util .concurrent .Semaphore ;
23+ import java .util .concurrent .atomic . AtomicBoolean ;
2224
2325public class BulkInferenceExecutor {
24- private static final String TASK_RUNNER_NAME = "bulk_inference_operation" ;
25- private static final int INFERENCE_RESPONSE_TIMEOUT = 30 ; // TODO: should be in the config.
2626 private final ThrottledInferenceRunner throttledInferenceRunner ;
27- private final ExecutorService executorService ;
2827
2928 public BulkInferenceExecutor (InferenceRunner inferenceRunner , ThreadPool threadPool , BulkInferenceExecutionConfig bulkExecutionConfig ) {
30- executorService = executorService (threadPool );
31- throttledInferenceRunner = ThrottledInferenceRunner .create (inferenceRunner , executorService , bulkExecutionConfig );
29+ throttledInferenceRunner = ThrottledInferenceRunner .create (inferenceRunner , executorService (threadPool ), bulkExecutionConfig );
3230 }
3331
3432 public void execute (BulkInferenceRequestIterator requests , ActionListener <List <InferenceAction .Response >> listener ) throws Exception {
35- final ResponseHandler responseHandler = new ResponseHandler ();
36- runInferenceRequests (requests , listener .delegateFailureAndWrap (responseHandler ::handleResponses ));
37- }
33+ if (requests .hasNext () == false ) {
34+ listener .onResponse (List .of ());
35+ return ;
36+ }
3837
39- private void runInferenceRequests (BulkInferenceRequestIterator requests , ActionListener <BulkInferenceExecutionState > listener ) {
4038 final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState ();
41- try {
42- executorService .execute (() -> {
43- while (bulkExecutionState .finished () == false && requests .hasNext ()) {
44- InferenceAction .Request request = requests .next ();
45- long seqNo = bulkExecutionState .generateSeqNo ();
46- throttledInferenceRunner .doInference (
47- request ,
48- ActionListener .wrap (
49- r -> bulkExecutionState .onInferenceResponse (seqNo , r ),
50- e -> bulkExecutionState .onInferenceException (seqNo , e )
51- )
52- );
53- }
39+ final ResponseHandler responseHandler = new ResponseHandler (bulkExecutionState , listener , requests .estimatedSize ());
40+
41+ while (bulkExecutionState .finished () == false && requests .hasNext ()) {
42+ InferenceAction .Request request = requests .next ();
43+ long seqNo = bulkExecutionState .generateSeqNo ();
44+
45+ if (requests .hasNext () == false ) {
5446 bulkExecutionState .finish ();
55- });
56- } catch (RejectedExecutionException e ) {
57- bulkExecutionState .addFailure (new IllegalStateException ("Unable to enqueue inference requests" , e ));
58- bulkExecutionState .finish ();
59- } finally {
60- listener .onResponse (bulkExecutionState );
47+ }
48+
49+ throttledInferenceRunner .doInference (
50+ request ,
51+ ActionListener .runAfter (
52+ ActionListener .wrap (
53+ r -> bulkExecutionState .onInferenceResponse (seqNo , r ),
54+ e -> bulkExecutionState .onInferenceException (seqNo , e )
55+ ),
56+ responseHandler ::persistsInferenceResponses
57+ )
58+ );
6159 }
6260 }
6361
6462 private static class ResponseHandler {
65- private final List <InferenceAction .Response > responses = new ArrayList <>();
66-
67- private void handleResponses (ActionListener <List <InferenceAction .Response >> listener , BulkInferenceExecutionState bulkExecutionState ) {
68-
69- try {
70- persistsInferenceResponses (bulkExecutionState );
71- } catch (InterruptedException | TimeoutException e ) {
72- bulkExecutionState .addFailure (e );
73- bulkExecutionState .finish ();
74- }
63+ private final List <InferenceAction .Response > responses ;
64+ private final ActionListener <List <InferenceAction .Response >> listener ;
65+ private final BulkInferenceExecutionState bulkExecutionState ;
66+ private final AtomicBoolean responseSent = new AtomicBoolean (false );
67+
68+ private ResponseHandler (
69+ BulkInferenceExecutionState bulkExecutionState ,
70+ ActionListener <List <InferenceAction .Response >> listener ,
71+ int estimatedSize
72+ ) {
73+ this .listener = listener ;
74+ this .bulkExecutionState = bulkExecutionState ;
75+ this .responses = new ArrayList <>(estimatedSize );
76+ }
7577
76- if (bulkExecutionState .hasFailure () == false ) {
77- try {
78- listener .onResponse (responses );
79- return ;
80- } catch (Exception e ) {
81- bulkExecutionState .addFailure (e );
78+ public synchronized void persistsInferenceResponses () {
79+ long persistedSeqNo = bulkExecutionState .getPersistedCheckpoint ();
80+
81+ while (persistedSeqNo < bulkExecutionState .getProcessedCheckpoint ()) {
82+ persistedSeqNo ++;
83+ InferenceAction .Response response = bulkExecutionState .fetchBufferedResponse (persistedSeqNo );
84+ assert response != null || bulkExecutionState .hasFailure ();
85+ if (bulkExecutionState .hasFailure () == false ) {
86+ try {
87+ responses .add (response );
88+ } catch (Exception e ) {
89+ bulkExecutionState .addFailure (e );
90+ }
8291 }
92+ bulkExecutionState .markSeqNoAsPersisted (persistedSeqNo );
8393 }
8494
85- listener . onFailure ( bulkExecutionState . getFailure () );
95+ sendResponseOnCompletion ( );
8696 }
8797
88- private void persistsInferenceResponses (BulkInferenceExecutionState bulkExecutionState ) throws TimeoutException ,
89- InterruptedException {
90- while (bulkExecutionState .finished () == false && bulkExecutionState .fetchProcessedSeqNo (INFERENCE_RESPONSE_TIMEOUT ) >= 0 ) {
91- long persistedSeqNo = bulkExecutionState .getPersistedCheckpoint ();
92-
93- while (persistedSeqNo < bulkExecutionState .getProcessedCheckpoint ()) {
94- persistedSeqNo ++;
95- InferenceAction .Response response = bulkExecutionState .fetchBufferedResponse (persistedSeqNo );
96- assert response != null || bulkExecutionState .hasFailure ();
97- if (bulkExecutionState .hasFailure () == false ) {
98- try {
99- responses .add (response );
100- } catch (Exception e ) {
101- bulkExecutionState .addFailure (e );
102- }
98+ private void sendResponseOnCompletion () {
99+ if (bulkExecutionState .finished () && responseSent .compareAndSet (false , true )) {
100+ if (bulkExecutionState .hasFailure () == false ) {
101+ try {
102+ listener .onResponse (responses );
103+ return ;
104+ } catch (Exception e ) {
105+ bulkExecutionState .addFailure (e );
103106 }
104- bulkExecutionState .markSeqNoAsPersisted (persistedSeqNo );
105107 }
108+
109+ listener .onFailure (bulkExecutionState .getFailure ());
106110 }
107111 }
108112 }
109113
110- private static class ThrottledInferenceRunner extends ThrottledTaskRunner {
114+ private static class ThrottledInferenceRunner {
111115 private final InferenceRunner inferenceRunner ;
116+ private final ExecutorService executorService ;
117+ private final BlockingQueue <AbstractRunnable > pendingRequests = ConcurrentCollections .newBlockingQueue ();
118+ private final Semaphore permits ;
112119
113120 private ThrottledInferenceRunner (InferenceRunner inferenceRunner , ExecutorService executorService , int maxRunningTasks ) {
114- super (TASK_RUNNER_NAME , maxRunningTasks , executorService );
121+ this .executorService = executorService ;
122+ this .permits = new Semaphore (maxRunningTasks );
115123 this .inferenceRunner = inferenceRunner ;
116124 }
117125
@@ -120,13 +128,58 @@ public static ThrottledInferenceRunner create(
120128 ExecutorService executorService ,
121129 BulkInferenceExecutionConfig bulkExecutionConfig
122130 ) {
123- return new ThrottledInferenceRunner (inferenceRunner , executorService , bulkExecutionConfig .workers ());
131+ return new ThrottledInferenceRunner (inferenceRunner , executorService , bulkExecutionConfig .maxOutstandingRequests ());
124132 }
125133
126134 public void doInference (InferenceAction .Request request , ActionListener <InferenceAction .Response > listener ) {
127- this .enqueueTask (listener .delegateFailureAndWrap ((l , releasable ) -> {
128- inferenceRunner .doInference (request , ActionListener .releaseAfter (l , releasable ));
129- }));
135+ enqueueTask (request , listener );
136+ executePendingRequests ();
137+ }
138+
139+ private void executePendingRequests () {
140+ while (permits .tryAcquire ()) {
141+ AbstractRunnable task = pendingRequests .poll ();
142+
143+ if (task == null ) {
144+ permits .release ();
145+ return ;
146+ }
147+
148+ try {
149+ executorService .execute (task );
150+ } catch (Exception e ){
151+ task .onFailure (e );
152+ permits .release ();
153+ }
154+ }
155+ }
156+
157+ private void enqueueTask (InferenceAction .Request request , ActionListener <InferenceAction .Response > listener ) {
158+ try {
159+ pendingRequests .add (createTask (request , listener ));
160+ executePendingRequests ();
161+ } catch (Exception e ) {
162+ listener .onFailure (new IllegalStateException ("An error occurred while adding the inference request to the queue" , e ));
163+ }
164+ }
165+
166+ private AbstractRunnable createTask (InferenceAction .Request request , ActionListener <InferenceAction .Response > listener ) {
167+ final ActionListener <InferenceAction .Response > completionListener = ActionListener .runAfter (listener , () -> {
168+ permits .release ();
169+ executePendingRequests ();
170+ });
171+
172+ return new AbstractRunnable () {
173+ @ Override
174+ protected void doRun () {
175+ inferenceRunner .doInference (request , completionListener );
176+ }
177+
178+ @ Override
179+ public void onFailure (Exception e ) {
180+ completionListener .onFailure (e );
181+ }
182+ };
130183 }
131184 }
132185
0 commit comments