2626import org .elasticsearch .xpack .inference .services .settings .RateLimitSettings ;
2727
2828import java .time .Clock ;
29+ import java .time .Duration ;
2930import java .time .Instant ;
3031import java .util .ArrayList ;
3132import java .util .List ;
@@ -125,7 +126,10 @@ interface RateLimiterCreator {
125126 private final AtomicInteger rateLimitDivisor = new AtomicInteger (1 );
126127 private final ThreadPool threadPool ;
127128 private final CountDownLatch startupLatch ;
128- private final CountDownLatch terminationLatch = new CountDownLatch (1 );
129+ // Two latches because we have two threads of execution, one thread blocking on a queue for items to be sent immediately, and
130+ // another threads that is scheduled on an interval that checks items that can be rate limited
131+ private final CountDownLatch immediateRequestQueueTerminationLatch = new CountDownLatch (1 );
132+ private final CountDownLatch rateLimitedTerminationLatch = new CountDownLatch (1 );
129133 private final RequestSender requestSender ;
130134 private final RequestExecutorServiceSettings settings ;
131135 private final Clock clock ;
@@ -184,11 +188,25 @@ public boolean isShutdown() {
184188 }
185189
186190 public boolean awaitTermination (long timeout , TimeUnit unit ) throws InterruptedException {
187- return terminationLatch .await (timeout , unit );
191+ var totalWait = Duration .ofMillis (unit .toMillis (timeout ));
192+
193+ var firstAwaitStart = Instant .now ();
194+ var firstLatchResult = immediateRequestQueueTerminationLatch .await (timeout , unit );
195+ var firstAwaitEnd = Instant .now ();
196+
197+ var remainingWaitTime = totalWait .minus (Duration .between (firstAwaitStart , firstAwaitEnd ));
198+
199+ // If the first latch await returns false, we've run out of time
200+ // If the remaining wait time is negative or zero, we've run out of time
201+ if (firstLatchResult == false || remainingWaitTime .isNegative () || remainingWaitTime .isZero ()) {
202+ return false ;
203+ }
204+
205+ return rateLimitedTerminationLatch .await (remainingWaitTime .toMillis (), TimeUnit .MILLISECONDS );
188206 }
189207
190208 public boolean isTerminated () {
191- return terminationLatch .getCount () == 0 ;
209+ return immediateRequestQueueTerminationLatch . getCount () == 0 && rateLimitedTerminationLatch .getCount () == 0 ;
192210 }
193211
194212 public int queueSize () {
@@ -213,6 +231,7 @@ public void start() {
213231 } catch (Exception e ) {
214232 logger .warn ("Failed to start request executor" , e );
215233 cleanup (CleanupStrategy .RATE_LIMITED_REQUEST_QUEUES_ONLY );
234+ cleanup (CleanupStrategy .REQUEST_QUEUE_ONLY );
216235 }
217236 }
218237
@@ -350,12 +369,16 @@ private void cleanup(CleanupStrategy cleanupStrategy) {
350369 shutdown ();
351370
352371 switch (cleanupStrategy ) {
353- case RATE_LIMITED_REQUEST_QUEUES_ONLY -> notifyRateLimitedRequestsOfShutdown ();
354- case REQUEST_QUEUE_ONLY -> rejectRequestsInRequestQueue ();
372+ case RATE_LIMITED_REQUEST_QUEUES_ONLY -> {
373+ notifyRateLimitedRequestsOfShutdown ();
374+ rateLimitedTerminationLatch .countDown ();
375+ }
376+ case REQUEST_QUEUE_ONLY -> {
377+ rejectRequestsInRequestQueue ();
378+ immediateRequestQueueTerminationLatch .countDown ();
379+ }
355380 default -> logger .error (Strings .format ("Unknown clean up strategy for request executor: [%s]" , cleanupStrategy .toString ()));
356381 }
357-
358- terminationLatch .countDown ();
359382 } catch (Exception e ) {
360383 logger .warn ("Encountered an error while cleaning up" , e );
361384 }
0 commit comments