diff --git a/docs/changelog/126858.yaml b/docs/changelog/126858.yaml new file mode 100644 index 0000000000000..d1ea2ebba73ef --- /dev/null +++ b/docs/changelog/126858.yaml @@ -0,0 +1,6 @@ +pr: 126858 +summary: Leverage threadpool schedule for inference api to avoid long running thread +area: Machine Learning +type: bug +issues: + - 126853 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 9bcd8e3dba44e..e3fff14bf95d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -57,15 +57,6 @@ */ public class RequestExecutorService implements RequestExecutor { - /** - * Provides dependency injection mainly for testing - */ - interface Sleeper { - void sleep(TimeValue sleepTime) throws InterruptedException; - } - - // default for tests - static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration()); // default for tests static final AdjustableCapacityBlockingQueue.QueueCreator DEFAULT_QUEUE_CREATOR = new AdjustableCapacityBlockingQueue.QueueCreator<>() { @@ -118,7 +109,6 @@ interface RateLimiterCreator { private final Clock clock; private final AtomicBoolean shutdown = new AtomicBoolean(false); private final AdjustableCapacityBlockingQueue.QueueCreator queueCreator; - private final Sleeper sleeper; private final RateLimiterCreator rateLimiterCreator; private final AtomicReference cancellableCleanupTask = new AtomicReference<>(); private final AtomicBoolean started = new AtomicBoolean(false); @@ -129,16 +119,7 @@ public RequestExecutorService( RequestExecutorServiceSettings settings, RequestSender requestSender ) { - this( - threadPool, - DEFAULT_QUEUE_CREATOR, - startupLatch, - settings, - requestSender, - Clock.systemUTC(), - DEFAULT_SLEEPER, - DEFAULT_RATE_LIMIT_CREATOR - ); + this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR); } public RequestExecutorService( @@ -148,7 +129,6 @@ public RequestExecutorService( RequestExecutorServiceSettings settings, RequestSender requestSender, Clock clock, - Sleeper sleeper, RateLimiterCreator rateLimiterCreator ) { this.threadPool = Objects.requireNonNull(threadPool); @@ -157,7 +137,6 @@ public RequestExecutorService( this.requestSender = Objects.requireNonNull(requestSender); this.settings = Objects.requireNonNull(settings); this.clock = Objects.requireNonNull(clock); - this.sleeper = Objects.requireNonNull(sleeper); this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); } @@ -213,15 +192,10 @@ public void start() { startCleanupTask(); signalStartInitiated(); - while (isShutdown() == false) { - handleTasks(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } finally { - shutdown(); - notifyRequestsOfShutdown(); - terminationLatch.countDown(); + handleTasks(); + } catch (Exception e) { + logger.warn("Failed to start request executor", e); + cleanup(); } } @@ -256,13 +230,44 @@ void removeStaleGroupings() { } } - private void handleTasks() throws InterruptedException { - var timeToWait = settings.getTaskPollFrequency(); - for (var endpoint : rateLimitGroupings.values()) { - timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); + private void scheduleNextHandleTasks(TimeValue timeToWait) { + if (shutdown.get()) { + logger.debug("Shutdown requested while scheduling next handle task call, cleaning up"); + cleanup(); + return; + } + + threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + private void cleanup() { + try { + shutdown(); + notifyRequestsOfShutdown(); + terminationLatch.countDown(); + } catch (Exception e) { + logger.warn("Encountered an error while cleaning up", e); } + } - sleeper.sleep(timeToWait); + private void handleTasks() { + try { + if (shutdown.get()) { + logger.debug("Shutdown requested while handling tasks, cleaning up"); + cleanup(); + return; + } + + var timeToWait = settings.getTaskPollFrequency(); + for (var endpoint : rateLimitGroupings.values()) { + timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); + } + + scheduleNextHandleTasks(timeToWait); + } catch (Exception e) { + logger.warn("Encountered an error while handling tasks", e); + cleanup(); + } } private void notifyRequestsOfShutdown() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 2c78b7358e9ba..81be52bc567e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -50,6 +50,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; @@ -90,7 +91,7 @@ public void shutdown() throws IOException, InterruptedException { } public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { - var senderFactory = createSenderFactory(clientManager, threadRef); + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { sender.start(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index 114d20b80590a..b1525b82a2381 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -51,7 +51,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -206,7 +205,7 @@ public void testExecute_Throws_WhenQueueIsFull() { assertFalse(thrownException.isExecutorShutdown()); } - public void testTaskThrowsError_CallsOnFailure() { + public void testTaskThrowsError_CallsOnFailure() throws InterruptedException { var requestSender = mock(RetryingHttpSender.class); var service = createRequestExecutorService(null, requestSender); @@ -229,6 +228,8 @@ public void testTaskThrowsError_CallsOnFailure() { var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id"))); assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class)); + service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + assertTrue(service.isTerminated()); } @@ -361,7 +362,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I createRequestExecutorServiceSettingsEmpty(), requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); @@ -375,36 +375,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I }); service.start(); - assertTrue(service.isTerminated()); - } - - public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception { - @SuppressWarnings("unchecked") - BlockingQueue queue = mock(LinkedBlockingQueue.class); - var sleeper = mock(RequestExecutorService.Sleeper.class); - doThrow(new InterruptedException("failed")).when(sleeper).sleep(any()); - - var service = new RequestExecutorService( - threadPool, - mockQueueCreator(queue), - null, - createRequestExecutorServiceSettingsEmpty(), - mock(RetryingHttpSender.class), - Clock.systemUTC(), - sleeper, - RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR - ); - - Future executorTermination = threadPool.generic().submit(() -> { - try { - service.start(); - } catch (Exception e) { - fail(Strings.format("Failed to shutdown executor: %s", e)); - } - }); - - executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); - + service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); assertTrue(service.isTerminated()); } @@ -581,7 +552,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() { settings, requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, rateLimiterCreator ); var requestManager = RequestManagerTests.createMock(requestSender); @@ -614,7 +584,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And settings, requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, rateLimiterCreator ); var requestManager = RequestManagerTests.createMock(requestSender); @@ -626,11 +595,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And doAnswer(invocation -> { service.shutdown(); + ActionListener passedListener = invocation.getArgument(4); + passedListener.onResponse(null); + return Void.TYPE; }).when(requestSender).send(any(), any(), any(), any(), any()); service.start(); + listener.actionGet(TIMEOUT); verify(requestSender, times(1)).send(any(), any(), any(), any(), any()); } @@ -648,7 +621,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() { settings, requestSender, clock, - RequestExecutorService.DEFAULT_SLEEPER, RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); var requestManager = RequestManagerTests.createMock(requestSender, "id1"); @@ -682,7 +654,6 @@ public void testStartsCleanupThread() { settings, requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR );