diff --git a/docs/changelog/136167.yaml b/docs/changelog/136167.yaml new file mode 100644 index 0000000000000..6b63e41225ed3 --- /dev/null +++ b/docs/changelog/136167.yaml @@ -0,0 +1,6 @@ +pr: 136167 +summary: "[Inference API] Remove worst-case additional 50ms latency for non-rate limited\ + \ requests" +area: Machine Learning +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java index c930acdc0f45e..2c81db1bed116 100644 --- a/server/src/main/java/org/elasticsearch/inference/InputType.java +++ b/server/src/main/java/org/elasticsearch/inference/InputType.java @@ -61,6 +61,10 @@ public static boolean isInternalTypeOrUnspecified(InputType inputType) { return inputType == InputType.INTERNAL_INGEST || inputType == InputType.INTERNAL_SEARCH || inputType == InputType.UNSPECIFIED; } + public static boolean isIngest(InputType inputType) { + return inputType == InputType.INGEST || inputType == InputType.INTERNAL_INGEST; + } + public static boolean isSpecified(InputType inputType) { return inputType != null && inputType != InputType.UNSPECIFIED; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 7138cd30aa4d1..4ce3b65866de1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.threadpool.Scheduler; import 
org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; @@ -33,6 +34,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -53,6 +55,29 @@ * {@link org.apache.http.client.methods.HttpUriRequest} to set a timeout for how long this executor will wait * attempting to execute a task (aka waiting for the connection manager to lease a connection). See * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. + * + * The request flow looks as follows: + * + * -------------> Add request to fast-path request queue. + * | + * | + * request NOT supporting + * rate limiting + * | + * | + * Request ------------| + * | + * | + * request supporting + * rate limiting + * | + * | + * ------------> {Rate Limit Group 1 -> Queue 1, ..., Rate Limit Group N -> Queue N} + * + * Explanation: Submit request to the queue for the specific rate limiting group. + * The rate limiting groups are polled at the same specified interval, + * which in the worst case introduces an additional latency of + * {@link RequestExecutorServiceSettings#getTaskPollFrequency()}. 
*/ public class RequestExecutorService implements RequestExecutor { @@ -109,6 +134,8 @@ interface RateLimiterCreator { private final RateLimiterCreator rateLimiterCreator; private final AtomicReference cancellableCleanupTask = new AtomicReference<>(); private final AtomicBoolean started = new AtomicBoolean(false); + private final AdjustableCapacityBlockingQueue requestQueue; + private volatile Future requestQueueTask; public RequestExecutorService( ThreadPool threadPool, @@ -135,10 +162,16 @@ public RequestExecutorService( this.settings = Objects.requireNonNull(settings); this.clock = Objects.requireNonNull(clock); this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); + this.requestQueue = new AdjustableCapacityBlockingQueue<>(queueCreator, settings.getQueueCapacity()); } public void shutdown() { if (shutdown.compareAndSet(false, true)) { + if (requestQueueTask != null) { + // Wakes up the queue in processRequestQueue + requestQueue.offer(NOOP_TASK); + } + if (cancellableCleanupTask.get() != null) { logger.debug(() -> "Stopping clean up thread"); cancellableCleanupTask.get().cancel(); @@ -159,7 +192,7 @@ public boolean isTerminated() { } public int queueSize() { - return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); + return requestQueue.size() + rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } /** @@ -174,12 +207,12 @@ public void start() { started.set(true); startCleanupTask(); + startRequestQueueTask(); signalStartInitiated(); - - handleTasks(); + startHandlingRateLimitedTasks(); } catch (Exception e) { logger.warn("Failed to start request executor", e); - cleanup(); + cleanup(CleanupStrategy.RATE_LIMITED_REQUEST_QUEUES_ONLY); } } @@ -194,6 +227,11 @@ private void startCleanupTask() { cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL)); } + private void startRequestQueueTask() { + assert requestQueueTask == null : "The 
request queue can only be started once"; + requestQueueTask = threadPool.executor(UTILITY_THREAD_POOL_NAME).submit(this::processRequestQueue); + } + private Scheduler.Cancellable startCleanupThread(TimeValue interval) { logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval)); @@ -217,30 +255,119 @@ void removeStaleGroupings() { private void scheduleNextHandleTasks(TimeValue timeToWait) { if (shutdown.get()) { logger.debug("Shutdown requested while scheduling next handle task call, cleaning up"); - cleanup(); + cleanup(CleanupStrategy.RATE_LIMITED_REQUEST_QUEUES_ONLY); return; } - threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + threadPool.schedule(this::startHandlingRateLimitedTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + private void processRequestQueue() { + try { + while (isShutdown() == false) { + var task = requestQueue.take(); + + if (task == NOOP_TASK) { + if (isShutdown()) { + logger.debug("Shutdown requested, exiting request queue processing"); + break; + } + + // Skip processing NoopTask + continue; + } + + if (isShutdown()) { + logger.debug("Shutdown requested while handling request tasks, cleaning up"); + rejectNonRateLimitedRequest(task); + break; + } + + executeTaskImmediately(task); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.debug("Inference request queue interrupted, exiting"); + } catch (Exception e) { + logger.error("Unexpected error processing request queue, terminating", e); + } finally { + cleanup(CleanupStrategy.REQUEST_QUEUE_ONLY); + } } - private void cleanup() { + private void executeTaskImmediately(RejectableTask task) { + try { + task.getRequestManager() + .execute(task.getInferenceInputs(), requestSender, task.getRequestCompletedFunction(), task.getListener()); + } catch (Exception e) { + logger.warn( + format("Failed to execute fast-path request for inference id [%s]", 
task.getRequestManager().inferenceEntityId()), + e + ); + + task.onRejection( + new EsRejectedExecutionException( + format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()), + false + ) + ); + } + } + + // visible for testing + void submitTaskToRateLimitedExecutionPath(RequestTask task) { + var requestManager = task.getRequestManager(); + var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { + var endpointHandler = new RateLimitingEndpointHandler( + Integer.toString(requestManager.rateLimitGrouping().hashCode()), + queueCreator, + settings, + requestSender, + clock, + requestManager.rateLimitSettings(), + this::isShutdown, + rateLimiterCreator, + rateLimitDivisor.get() + ); + + endpointHandler.init(); + return endpointHandler; + }); + + endpoint.enqueue(task); + } + + private static boolean isEmbeddingsIngestInput(InferenceInputs inputs) { + return inputs instanceof EmbeddingsInput embeddingsInput && InputType.isIngest(embeddingsInput.getInputType()); + } + + private static boolean rateLimitingEnabled(RateLimitSettings rateLimitSettings) { + return rateLimitSettings != null && rateLimitSettings.isEnabled(); + } + + private void cleanup(CleanupStrategy cleanupStrategy) { try { shutdown(); - notifyRequestsOfShutdown(); + + switch (cleanupStrategy) { + case RATE_LIMITED_REQUEST_QUEUES_ONLY -> notifyRateLimitedRequestsOfShutdown(); + case REQUEST_QUEUE_ONLY -> rejectRequestsInRequestQueue(); + default -> logger.error(Strings.format("Unknown clean up strategy for request executor: [%s]", cleanupStrategy.toString())); + } + terminationLatch.countDown(); } catch (Exception e) { logger.warn("Encountered an error while cleaning up", e); } } - private void handleTasks() { + private void startHandlingRateLimitedTasks() { try { TimeValue timeToWait; do { - if (shutdown.get()) { - logger.debug("Shutdown requested while handling tasks, cleaning up"); - cleanup(); + if (isShutdown()) { + 
logger.debug("Shutdown requested while handling rate limited tasks, cleaning up"); + cleanup(CleanupStrategy.RATE_LIMITED_REQUEST_QUEUES_ONLY); return; } @@ -253,12 +380,12 @@ private void handleTasks() { scheduleNextHandleTasks(timeToWait); } catch (Exception e) { - logger.warn("Encountered an error while handling tasks", e); - cleanup(); + logger.warn("Encountered an error while handling rate limited tasks", e); + cleanup(CleanupStrategy.RATE_LIMITED_REQUEST_QUEUES_ONLY); } } - private void notifyRequestsOfShutdown() { + private void notifyRateLimitedRequestsOfShutdown() { assert isShutdown() : "Requests should only be notified if the executor is shutting down"; for (var endpoint : rateLimitGroupings.values()) { @@ -266,6 +393,41 @@ private void notifyRequestsOfShutdown() { } } + private void rejectRequestsInRequestQueue() { + assert isShutdown() : "Requests in request queue should only be notified if the executor is shutting down"; + + List requests = new ArrayList<>(); + requestQueue.drainTo(requests); + + for (var request : requests) { + // NoopTask does not implement being rejected, therefore we need to skip it + if (request != NOOP_TASK) { + rejectNonRateLimitedRequest(request); + } + } + } + + private void rejectNonRateLimitedRequest(RejectableTask task) { + var inferenceEntityId = task.getRequestManager().inferenceEntityId(); + + rejectRequest( + task, + format( + "Failed to send request for inference id [%s] because the request executor service has been shutdown", + inferenceEntityId + ), + format("Failed to notify request for inference id [%s] of rejection after executor service shutdown", inferenceEntityId) + ); + } + + private static void rejectRequest(RejectableTask task, String rejectionMessage, String rejectionFailedMessage) { + try { + task.onRejection(new EsRejectedExecutionException(rejectionMessage, true)); + } catch (Exception e) { + logger.warn(rejectionFailedMessage, e); + } + } + // default for testing Integer 
remainingQueueCapacity(RequestManager requestManager) { var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping()); @@ -308,26 +470,33 @@ public void execute( ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) ); - var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { - var endpointHandler = new RateLimitingEndpointHandler( - Integer.toString(requestManager.rateLimitGrouping().hashCode()), - queueCreator, - settings, - requestSender, - clock, - requestManager.rateLimitSettings(), - this::isShutdown, - rateLimiterCreator, - rateLimitDivisor.get() + if (isShutdown()) { + task.onRejection( + new EsRejectedExecutionException( + format( + "Failed to enqueue request task for inference id [%s] because the request executor service has been shutdown", + requestManager.inferenceEntityId() + ), + true + ) ); + return; + } - // TODO: add or create/compute if absent set for new map (service/task-type-key -> rate limit endpoint handler) - - endpointHandler.init(); - return endpointHandler; - }); + if (isEmbeddingsIngestInput(inferenceInputs) || rateLimitingEnabled(requestManager.rateLimitSettings())) { + submitTaskToRateLimitedExecutionPath(task); + } else { + boolean taskAccepted = requestQueue.offer(task); - endpoint.enqueue(task); + if (taskAccepted == false) { + task.onRejection( + new EsRejectedExecutionException( + format("Failed to enqueue request task for inference id [%s]", requestManager.inferenceEntityId()), + false + ) + ); + } + } } /** @@ -345,7 +514,7 @@ static class RateLimitingEndpointHandler { private final AdjustableCapacityBlockingQueue queue; private final Supplier isShutdownMethod; private final RequestSender requestSender; - private final String id; + private final String rateLimitGroupingId; private final AtomicReference timeOfLastEnqueue = new AtomicReference<>(); private final Clock clock; private final RateLimiter rateLimiter; @@ -354,7 +523,7 @@ static 
class RateLimitingEndpointHandler { private final Long originalRequestsPerTimeUnit; RateLimitingEndpointHandler( - String id, + String rateLimitGroupingId, AdjustableCapacityBlockingQueue.QueueCreator createQueue, RequestExecutorServiceSettings settings, RequestSender requestSender, @@ -365,7 +534,7 @@ static class RateLimitingEndpointHandler { Integer rateLimitDivisor ) { this.requestExecutorServiceSettings = Objects.requireNonNull(settings); - this.id = Objects.requireNonNull(id); + this.rateLimitGroupingId = Objects.requireNonNull(rateLimitGroupingId); this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); this.requestSender = Objects.requireNonNull(requestSender); this.clock = Objects.requireNonNull(clock); @@ -383,20 +552,25 @@ static class RateLimitingEndpointHandler { } public void init() { - requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange); - } - - public String id() { - return id; + requestExecutorServiceSettings.registerQueueCapacityCallback(rateLimitGroupingId, this::onCapacityChange); } private void onCapacityChange(int capacity) { - logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity)); + logger.debug( + () -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", rateLimitGroupingId, capacity) + ); try { queue.setCapacity(capacity); } catch (Exception e) { - logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e); + logger.warn( + format( + "Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", + rateLimitGroupingId, + capacity + ), + e + ); } } @@ -416,7 +590,7 @@ public synchronized TimeValue executeEnqueuedTask() { try { return executeEnqueuedTaskInternal(); } catch (Exception e) { - logger.warn(format("Executor service grouping [%s] failed to execute request", id), e); + 
logger.warn(format("Executor service grouping [%s] failed to execute request", rateLimitGroupingId), e); // we tried to do some work but failed, so we'll say we did something to try looking for more work return EXECUTED_A_TASK; } @@ -472,7 +646,7 @@ public void enqueue(RequestTask task) { format( "Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown", task.getRequestManager().inferenceEntityId(), - id + rateLimitGroupingId ), true ); @@ -488,7 +662,7 @@ public void enqueue(RequestTask task) { format( "Failed to execute task for inference id [%s] because the request service [%s] queue is full", task.getRequestManager().inferenceEntityId(), - id + rateLimitGroupingId ), false ); @@ -508,34 +682,25 @@ public synchronized void notifyRequestsOfShutdown() { rejectTasks(notExecuted); } catch (Exception e) { - logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id)); + logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", rateLimitGroupingId), e); } } private void rejectTasks(List tasks) { for (var task : tasks) { - rejectTaskForShutdown(task); - } - } + var inferenceEntityId = task.getRequestManager().inferenceEntityId(); - private void rejectTaskForShutdown(RejectableTask task) { - try { - task.onRejection( - new EsRejectedExecutionException( - format( - "Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request", - id, - task.getRequestManager().inferenceEntityId() - ), - true - ) - ); - } catch (Exception e) { - logger.warn( + rejectRequest( + task, + format( + "Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request", + rateLimitGroupingId, + inferenceEntityId + ), format( "Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown", - task.getRequestManager().inferenceEntityId(), - id + inferenceEntityId, + 
rateLimitGroupingId ) ); } @@ -546,7 +711,44 @@ public int remainingCapacity() { } public void close() { - requestExecutorServiceSettings.deregisterQueueCapacityCallback(id); + requestExecutorServiceSettings.deregisterQueueCapacityCallback(rateLimitGroupingId); + } + } + + private static final RejectableTask NOOP_TASK = new RejectableTask() { + @Override + public void onRejection(Exception e) { + throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); + } + + @Override + public RequestManager getRequestManager() { + throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); + } + + @Override + public InferenceInputs getInferenceInputs() { + throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); + } + + @Override + public ActionListener getListener() { + throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); + } + + @Override + public boolean hasCompleted() { + throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); + } + + @Override + public Supplier getRequestCompletedFunction() { + throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); } + }; + + private enum CleanupStrategy { + REQUEST_QUEUE_ONLY, + RATE_LIMITED_REQUEST_QUEUES_ONLY } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java index c66bc8d33c05b..2bf687a8992da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java @@ 
-125,7 +125,7 @@ public RateLimitSettings(long requestsPerTimeUnit, TimeUnit timeUnit) { } // This should only be used for testing. - RateLimitSettings(long requestsPerTimeUnit, TimeUnit timeUnit, boolean enabled) { + public RateLimitSettings(long requestsPerTimeUnit, TimeUnit timeUnit, boolean enabled) { if (requestsPerTimeUnit <= 0) { throw new IllegalArgumentException("requests per minute must be positive"); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java index c6e08a8d5bdf3..9f2055d825fd7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -64,6 +64,10 @@ public static InputType randomWithInternalAndUnspecified() { return randomFrom(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED); } + public static InputType randomIngest() { + return randomFrom(InputType.INGEST, InputType.INTERNAL_INGEST); + } + public void testFromRestString_ValidInputType() { for (String internal : List.of("search", "ingest", "classification", "clustering", "unspecified")) { assertEquals(InputType.fromRestString(internal), InputType.fromString(internal)); @@ -211,4 +215,8 @@ public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIs ) ); } + + public void testIsIngest() { + assertTrue(InputType.isIngest(randomFrom(InputType.INGEST, InputType.INTERNAL_INGEST))); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 027a19aca6d1f..2cde70720dabe 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -354,7 +354,12 @@ public void testHttpRequestSender_Throws_WhenCallingSendBeforeStart() throws Exc PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( AssertionError.class, - () -> sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), null, listener) + () -> sender.send( + RequestManagerTests.createMockWithRateLimitingEnabled(), + new EmbeddingsInput(List.of(), null), + null, + listener + ) ); assertThat(thrownException.getMessage(), is("call start() before sending a request")); } @@ -375,7 +380,12 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception { sender.startSynchronously(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener); + sender.send( + RequestManagerTests.createMockWithRateLimitingEnabled(), + new EmbeddingsInput(List.of(), null), + TimeValue.timeValueNanos(1), + listener + ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -397,7 +407,12 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws sender.startSynchronously(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener); + sender.send( + RequestManagerTests.createMockWithRateLimitingEnabled(), + new EmbeddingsInput(List.of(), null), + TimeValue.timeValueNanos(1), + listener + ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index 163c4b84f1780..3cfecd5a9cd2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -34,17 +34,15 @@ import java.time.Duration; import java.time.Instant; import java.util.List; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueueTests.mockQueueCreator; import static org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettingsTests.createRequestExecutorServiceSettings; import static org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettingsTests.createRequestExecutorServiceSettingsEmpty; import static org.hamcrest.Matchers.instanceOf; @@ -82,7 +80,7 @@ public void testQueueSize_IsEmpty() { public void testQueueSize_IsOne() { var service = createRequestExecutorServiceWithMocks(); service.execute( - RequestManagerTests.createMock(), + RequestManagerTests.createMockWithRateLimitingEnabled(), new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, new PlainActionFuture<>() @@ -161,7 +159,7 @@ public void 
testExecute_AfterShutdown_Throws() { service.shutdown(); - var requestManager = RequestManagerTests.createMock("id"); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled("id"); var listener = new PlainActionFuture(); service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); @@ -171,7 +169,7 @@ public void testExecute_AfterShutdown_Throws() { thrownException.getMessage(), is( Strings.format( - "Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown", + "Failed to enqueue request task for inference id [id] because the request executor service has been shutdown", requestManager.rateLimitGrouping().hashCode() ) ) @@ -179,19 +177,20 @@ public void testExecute_AfterShutdown_Throws() { assertTrue(thrownException.isExecutorShutdown()); } - public void testExecute_Throws_WhenQueueIsFull() { + public void testExecute_Throws_WhenRateLimitedQueueIsFull() { var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class)); + service.start(); service.execute( - RequestManagerTests.createMock(), - new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), + RequestManagerTests.createMockWithRateLimitingEnabled(), + new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, new PlainActionFuture<>() ); - var requestManager = RequestManagerTests.createMock("id"); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled("id"); var listener = new PlainActionFuture(); - service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); + service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); @@ -251,7 +250,7 @@ public void 
testExecute_CallsOnFailure_WhenRequestTimesOut() { var listener = new PlainActionFuture(); service.execute( - RequestManagerTests.createMock(), + RequestManagerTests.createMockWithRateLimitingEnabled(), new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), TimeValue.timeValueNanos(1), listener @@ -312,7 +311,7 @@ public void onFailure(Exception e) { }; service.execute( - RequestManagerTests.createMock(requestSender), + RequestManagerTests.createMockWithRateLimitingEnabled(requestSender), new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener @@ -326,10 +325,10 @@ public void onFailure(Exception e) { finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); } - public void testExecute_NotifiesTasksOfShutdown() { + public void testExecute_NotifiesNonRateLimitedTasksOfShutdown() { var service = createRequestExecutorServiceWithMocks(); - var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id"); + var requestManager = RequestManagerTests.createMockWithRateLimitingDisabled(mock(RequestSender.class), "id"); var listener = new PlainActionFuture(); service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); @@ -338,6 +337,28 @@ public void testExecute_NotifiesTasksOfShutdown() { var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send request for inference id [id] because the request executor service has been shutdown") + ); + assertTrue(thrownException.isExecutorShutdown()); + assertTrue(service.isTerminated()); + } + + public void testExecute_NotifiesRateLimitedTasksOfShutdown() { + var service = createRequestExecutorServiceWithMocks(); + + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(mock(RequestSender.class), "id"); + var listener = new PlainActionFuture(); + 
service.submitTaskToRateLimitedExecutionPath( + new RequestTask(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, threadPool, listener) + ); + + service.shutdown(); + service.start(); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( thrownException.getMessage(), is( @@ -351,33 +372,34 @@ public void testExecute_NotifiesTasksOfShutdown() { assertTrue(service.isTerminated()); } - public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { - @SuppressWarnings("unchecked") - BlockingQueue queue = mock(LinkedBlockingQueue.class); - + public void testTask_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { var requestSender = mock(RetryingHttpSender.class); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id"); + CountDownLatch taskProcessedLatch = new CountDownLatch(1); - var service = new RequestExecutorService( - threadPool, - mockQueueCreator(queue), - null, - createRequestExecutorServiceSettingsEmpty(), - requestSender, - Clock.systemUTC(), - RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR - ); + doAnswer(invocation -> { + taskProcessedLatch.countDown(); + throw new ElasticsearchException("failed"); + }).when(requestManager).execute(any(), any(), any(), any()); + + var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettingsEmpty(), requestSender); + + service.start(); PlainActionFuture listener = new PlainActionFuture<>(); - var requestManager = RequestManagerTests.createMock(requestSender, "id"); service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); - when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { - service.shutdown(); - return null; - }); - service.start(); + // Wait for throwing task to be executed + 
assertTrue(taskProcessedLatch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS)); + + // Make sure service is still running after processing a task, which threw an Exception + assertFalse(service.isShutdown()); + assertFalse(service.isTerminated()); + + service.shutdown(); service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + assertTrue(service.isShutdown()); assertTrue(service.isTerminated()); } @@ -388,26 +410,21 @@ public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, var service = new RequestExecutorService(threadPool, null, settings, requestSender); service.execute( - RequestManagerTests.createMock(requestSender), - new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), + RequestManagerTests.createMockWithRateLimitingEnabled(requestSender), + new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, new PlainActionFuture<>() ); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - var requestManager = RequestManagerTests.createMock(requestSender, "id"); - service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id"); + service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is( - Strings.format( - "Failed to execute task for inference id [id] because the request service [%s] queue is full", - requestManager.rateLimitGrouping().hashCode() - ) - ) + is("Failed to execute task for inference id [id] because the request service [3355] queue is full") ); settings.setQueueCapacity(2); @@ -437,22 +454,30 @@ public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull( var settings = 
createRequestExecutorServiceSettings(3); var service = new RequestExecutorService(threadPool, null, settings, requestSender); - service.execute( - RequestManagerTests.createMock(requestSender, "id"), - new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), - null, - new PlainActionFuture<>() + service.submitTaskToRateLimitedExecutionPath( + new RequestTask( + RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id"), + new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), + null, + threadPool, + new PlainActionFuture<>() + ) ); - service.execute( - RequestManagerTests.createMock(requestSender, "id"), - new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), - null, - new PlainActionFuture<>() + service.submitTaskToRateLimitedExecutionPath( + new RequestTask( + RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id"), + new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), + null, + threadPool, + new PlainActionFuture<>() + ) ); PlainActionFuture listener = new PlainActionFuture<>(); - var requestManager = RequestManagerTests.createMock(requestSender, "id"); - service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id"); + service.submitTaskToRateLimitedExecutionPath( + new RequestTask(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, threadPool, listener) + ); assertThat(service.queueSize(), is(3)); settings.setQueueCapacity(1); @@ -497,15 +522,15 @@ public void testChangingCapacity_ToZero_SetsQueueCapacityToUnbounded() throws IO var settings = createRequestExecutorServiceSettings(1); var service = new RequestExecutorService(threadPool, null, settings, requestSender); - var requestManager = RequestManagerTests.createMock(requestSender); + var requestManager = 
RequestManagerTests.createMockWithRateLimitingEnabled(requestSender); - service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, new PlainActionFuture<>()); + service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); service.execute( - RequestManagerTests.createMock(requestSender, "id"), - new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), + RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id"), + new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, listener ); @@ -556,7 +581,7 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() { Clock.systemUTC(), rateLimiterCreator ); - var requestManager = RequestManagerTests.createMock(requestSender); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender); PlainActionFuture listener = new PlainActionFuture<>(); service.execute(requestManager, new EmbeddingsInput(List.of(), null), null, listener); @@ -628,10 +653,10 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And Clock.systemUTC(), rateLimiterCreator ); - var requestManager = RequestManagerTests.createMock(requestSender); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); + service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomIngest()), null, listener); when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0)); @@ -665,10 +690,12 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() { clock, 
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); - var requestManager = RequestManagerTests.createMock(requestSender, "id1"); + var requestManager = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id1"); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); + service.submitTaskToRateLimitedExecutionPath( + new RequestTask(requestManager, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, threadPool, listener) + ); assertThat(service.numberOfRateLimitGroups(), is(1)); // the time is moved to after the stale duration, so now we should remove this grouping @@ -676,8 +703,10 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() { service.removeStaleGroupings(); assertThat(service.numberOfRateLimitGroups(), is(0)); - var requestManager2 = RequestManagerTests.createMock(requestSender, "id2"); - service.execute(requestManager2, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, listener); + var requestManager2 = RequestManagerTests.createMockWithRateLimitingEnabled(requestSender, "id2"); + service.submitTaskToRateLimitedExecutionPath( + new RequestTask(requestManager2, new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()), null, threadPool, listener) + ); assertThat(service.numberOfRateLimitGroups(), is(1)); } @@ -707,6 +736,31 @@ public void testStartsCleanupThread() { assertThat(argument.getValue(), is(TimeValue.timeValueDays(1))); } + public void testStartsRequestQueueTask() { + var mockExecutorService = mock(ExecutorService.class); + when(mockExecutorService.submit(any(Runnable.class))).thenAnswer(i -> mock(Future.class)); + + var mockThreadPool = mock(ThreadPool.class); + when(mockThreadPool.executor(any())).thenReturn(mockExecutorService); + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(2, 
TimeValue.timeValueDays(1)); + var service = new RequestExecutorService( + mockThreadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR + ); + + service.shutdown(); + service.start(); + + verify(mockExecutorService, times(1)).submit(any(Runnable.class)); + } + private Future submitShutdownRequest( CountDownLatch waitToShutdown, CountDownLatch waitToReturnFromSend, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java index d8a1f2c4227e4..773590fb0b0c6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java @@ -15,26 +15,36 @@ import org.elasticsearch.xpack.inference.external.request.RequestTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.util.concurrent.TimeUnit; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class RequestManagerTests { - public static RequestManager createMock() { - return createMock(mock(RequestSender.class)); + public static RequestManager createMockWithRateLimitingDisabled(RequestSender requestSender, String inferenceEntityId) { + return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1, TimeUnit.MINUTES, false)); + } + + public static RequestManager createMockWithRateLimitingDisabled(String inferenceEntityId) { + return createMock(mock(RequestSender.class), inferenceEntityId, new RateLimitSettings(1, TimeUnit.MINUTES, false)); + } + + public 
static RequestManager createMockWithRateLimitingEnabled() { + return createMockWithRateLimitingEnabled(mock(RequestSender.class)); } - public static RequestManager createMock(String inferenceEntityId) { - return createMock(mock(RequestSender.class), inferenceEntityId); + public static RequestManager createMockWithRateLimitingEnabled(String inferenceEntityId) { + return createMockWithRateLimitingEnabled(mock(RequestSender.class), inferenceEntityId); } - public static RequestManager createMock(RequestSender requestSender) { - return createMock(requestSender, "id", new RateLimitSettings(1)); + public static RequestManager createMockWithRateLimitingEnabled(RequestSender requestSender) { + return createMock(requestSender, "id", new RateLimitSettings(1, TimeUnit.MINUTES, true)); } - public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) { - return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1)); + public static RequestManager createMockWithRateLimitingEnabled(RequestSender requestSender, String inferenceEntityId) { + return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1, TimeUnit.MINUTES, true)); } public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) {