Skip to content

Commit 5bd7f7f

Browse files
[ML] Refactor inference request executor to leverage scheduled execution (elastic#126858)
* Using threadpool schedule and fixing tests * Update docs/changelog/126858.yaml * Clean up * change log
1 parent 6763e06 commit 5bd7f7f

File tree

4 files changed

+57
-74
lines changed

4 files changed

+57
-74
lines changed

docs/changelog/126858.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126858
2+
summary: Leverage threadpool schedule for inference api to avoid long running thread
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 126853

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@
5757
*/
5858
public class RequestExecutorService implements RequestExecutor {
5959

60-
/**
61-
* Provides dependency injection mainly for testing
62-
*/
63-
interface Sleeper {
64-
void sleep(TimeValue sleepTime) throws InterruptedException;
65-
}
66-
67-
// default for tests
68-
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
6960
// default for tests
7061
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
7162
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@@ -118,7 +109,6 @@ interface RateLimiterCreator {
118109
private final Clock clock;
119110
private final AtomicBoolean shutdown = new AtomicBoolean(false);
120111
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
121-
private final Sleeper sleeper;
122112
private final RateLimiterCreator rateLimiterCreator;
123113
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
124114
private final AtomicBoolean started = new AtomicBoolean(false);
@@ -129,16 +119,7 @@ public RequestExecutorService(
129119
RequestExecutorServiceSettings settings,
130120
RequestSender requestSender
131121
) {
132-
this(
133-
threadPool,
134-
DEFAULT_QUEUE_CREATOR,
135-
startupLatch,
136-
settings,
137-
requestSender,
138-
Clock.systemUTC(),
139-
DEFAULT_SLEEPER,
140-
DEFAULT_RATE_LIMIT_CREATOR
141-
);
122+
this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
142123
}
143124

144125
public RequestExecutorService(
@@ -148,7 +129,6 @@ public RequestExecutorService(
148129
RequestExecutorServiceSettings settings,
149130
RequestSender requestSender,
150131
Clock clock,
151-
Sleeper sleeper,
152132
RateLimiterCreator rateLimiterCreator
153133
) {
154134
this.threadPool = Objects.requireNonNull(threadPool);
@@ -157,7 +137,6 @@ public RequestExecutorService(
157137
this.requestSender = Objects.requireNonNull(requestSender);
158138
this.settings = Objects.requireNonNull(settings);
159139
this.clock = Objects.requireNonNull(clock);
160-
this.sleeper = Objects.requireNonNull(sleeper);
161140
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
162141
}
163142

@@ -213,15 +192,10 @@ public void start() {
213192
startCleanupTask();
214193
signalStartInitiated();
215194

216-
while (isShutdown() == false) {
217-
handleTasks();
218-
}
219-
} catch (InterruptedException e) {
220-
Thread.currentThread().interrupt();
221-
} finally {
222-
shutdown();
223-
notifyRequestsOfShutdown();
224-
terminationLatch.countDown();
195+
handleTasks();
196+
} catch (Exception e) {
197+
logger.warn("Failed to start request executor", e);
198+
cleanup();
225199
}
226200
}
227201

@@ -256,13 +230,44 @@ void removeStaleGroupings() {
256230
}
257231
}
258232

259-
private void handleTasks() throws InterruptedException {
260-
var timeToWait = settings.getTaskPollFrequency();
261-
for (var endpoint : rateLimitGroupings.values()) {
262-
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
233+
private void scheduleNextHandleTasks(TimeValue timeToWait) {
234+
if (shutdown.get()) {
235+
logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
236+
cleanup();
237+
return;
238+
}
239+
240+
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
241+
}
242+
243+
private void cleanup() {
244+
try {
245+
shutdown();
246+
notifyRequestsOfShutdown();
247+
terminationLatch.countDown();
248+
} catch (Exception e) {
249+
logger.warn("Encountered an error while cleaning up", e);
263250
}
251+
}
264252

265-
sleeper.sleep(timeToWait);
253+
private void handleTasks() {
254+
try {
255+
if (shutdown.get()) {
256+
logger.debug("Shutdown requested while handling tasks, cleaning up");
257+
cleanup();
258+
return;
259+
}
260+
261+
var timeToWait = settings.getTaskPollFrequency();
262+
for (var endpoint : rateLimitGroupings.values()) {
263+
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
264+
}
265+
266+
scheduleNextHandleTasks(timeToWait);
267+
} catch (Exception e) {
268+
logger.warn("Encountered an error while handling tasks", e);
269+
cleanup();
270+
}
266271
}
267272

268273
private void notifyRequestsOfShutdown() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
5151
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
5252
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
53+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
5354
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
5455
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
5556
import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;
@@ -90,7 +91,7 @@ public void shutdown() throws IOException, InterruptedException {
9091
}
9192

9293
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
93-
var senderFactory = createSenderFactory(clientManager, threadRef);
94+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
9495

9596
try (var sender = createSender(senderFactory)) {
9697
sender.start();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import static org.mockito.ArgumentMatchers.any;
5252
import static org.mockito.ArgumentMatchers.anyInt;
5353
import static org.mockito.Mockito.doAnswer;
54-
import static org.mockito.Mockito.doThrow;
5554
import static org.mockito.Mockito.mock;
5655
import static org.mockito.Mockito.times;
5756
import static org.mockito.Mockito.verify;
@@ -206,7 +205,7 @@ public void testExecute_Throws_WhenQueueIsFull() {
206205
assertFalse(thrownException.isExecutorShutdown());
207206
}
208207

209-
public void testTaskThrowsError_CallsOnFailure() {
208+
public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
210209
var requestSender = mock(RetryingHttpSender.class);
211210

212211
var service = createRequestExecutorService(null, requestSender);
@@ -229,6 +228,8 @@ public void testTaskThrowsError_CallsOnFailure() {
229228
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
230229
assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
231230
assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
231+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
232+
232233
assertTrue(service.isTerminated());
233234
}
234235

@@ -361,7 +362,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
361362
createRequestExecutorServiceSettingsEmpty(),
362363
requestSender,
363364
Clock.systemUTC(),
364-
RequestExecutorService.DEFAULT_SLEEPER,
365365
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
366366
);
367367

@@ -375,36 +375,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
375375
});
376376
service.start();
377377

378-
assertTrue(service.isTerminated());
379-
}
380-
381-
public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
382-
@SuppressWarnings("unchecked")
383-
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
384-
var sleeper = mock(RequestExecutorService.Sleeper.class);
385-
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
386-
387-
var service = new RequestExecutorService(
388-
threadPool,
389-
mockQueueCreator(queue),
390-
null,
391-
createRequestExecutorServiceSettingsEmpty(),
392-
mock(RetryingHttpSender.class),
393-
Clock.systemUTC(),
394-
sleeper,
395-
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
396-
);
397-
398-
Future<?> executorTermination = threadPool.generic().submit(() -> {
399-
try {
400-
service.start();
401-
} catch (Exception e) {
402-
fail(Strings.format("Failed to shutdown executor: %s", e));
403-
}
404-
});
405-
406-
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
407-
378+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
408379
assertTrue(service.isTerminated());
409380
}
410381

@@ -581,7 +552,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
581552
settings,
582553
requestSender,
583554
Clock.systemUTC(),
584-
RequestExecutorService.DEFAULT_SLEEPER,
585555
rateLimiterCreator
586556
);
587557
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -614,7 +584,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
614584
settings,
615585
requestSender,
616586
Clock.systemUTC(),
617-
RequestExecutorService.DEFAULT_SLEEPER,
618587
rateLimiterCreator
619588
);
620589
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -626,11 +595,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
626595

627596
doAnswer(invocation -> {
628597
service.shutdown();
598+
ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
599+
passedListener.onResponse(null);
600+
629601
return Void.TYPE;
630602
}).when(requestSender).send(any(), any(), any(), any(), any());
631603

632604
service.start();
633605

606+
listener.actionGet(TIMEOUT);
634607
verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
635608
}
636609

@@ -648,7 +621,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() {
648621
settings,
649622
requestSender,
650623
clock,
651-
RequestExecutorService.DEFAULT_SLEEPER,
652624
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
653625
);
654626
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
@@ -682,7 +654,6 @@ public void testStartsCleanupThread() {
682654
settings,
683655
requestSender,
684656
Clock.systemUTC(),
685-
RequestExecutorService.DEFAULT_SLEEPER,
686657
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
687658
);
688659

0 commit comments

Comments
 (0)