-
Notifications
You must be signed in to change notification settings - Fork 25.5k
[Inference API] Remove worst-case additional 50ms latency for non-rate limited requests #136167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 19 commits
f1e7419
a326e31
191029f
edaa0f0
57c5605
f98b82e
f17ec00
90ee1a1
a9e7610
ec513be
174526c
f506cb3
ae349fd
540f49d
0dca88a
0590561
2e65475
4fb2372
2930151
b2fd85f
91f387a
90e672b
1cf24dc
a92a7c0
8e52c22
a868152
dd53fcb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pr: 136167 | ||
summary: "[Inference API] Remove worst-case additional 50ms latency for non-rate limited\ | ||
\ requests" | ||
area: Machine Learning | ||
type: bug | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.ConcurrentMap; | ||
import java.util.concurrent.CountDownLatch; | ||
import java.util.concurrent.Future; | ||
import java.util.concurrent.LinkedBlockingQueue; | ||
import java.util.concurrent.TimeUnit; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
|
@@ -53,6 +54,29 @@ | |
* {@link org.apache.http.client.methods.HttpUriRequest} to set a timeout for how long this executor will wait | ||
* attempting to execute a task (aka waiting for the connection manager to lease a connection). See | ||
* {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. | ||
* | ||
* The request flow looks as follows: | ||
* | ||
* -------------> Add request to fast-path request queue. | ||
* | | ||
* | | ||
* request NOT supporting | ||
* rate limiting | ||
* | | ||
* | | ||
* Request ------------| | ||
* | | ||
* | | ||
* request supporting | ||
* rate limiting | ||
* | | ||
* | | ||
* ------------> {Rate Limit Group 1 -> Queue 1, ..., Rate Limit Group N -> Queue N} | ||
* | ||
* Explanation: Submit request to the queue for the specific rate limiting group. | ||
* The rate limiting groups are polled at the same specified interval, | ||
* which in the worst cases introduces an additional latency of | ||
* {@link RequestExecutorServiceSettings#getTaskPollFrequency()}. | ||
*/ | ||
public class RequestExecutorService implements RequestExecutor { | ||
|
||
|
@@ -109,6 +133,8 @@ interface RateLimiterCreator { | |
private final RateLimiterCreator rateLimiterCreator; | ||
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>(); | ||
private final AtomicBoolean started = new AtomicBoolean(false); | ||
private final AdjustableCapacityBlockingQueue<RejectableTask> requestQueue; | ||
private volatile Future<?> requestQueueTask; | ||
|
||
public RequestExecutorService( | ||
ThreadPool threadPool, | ||
|
@@ -135,10 +161,16 @@ public RequestExecutorService( | |
this.settings = Objects.requireNonNull(settings); | ||
this.clock = Objects.requireNonNull(clock); | ||
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); | ||
this.requestQueue = new AdjustableCapacityBlockingQueue<>(queueCreator, settings.getQueueCapacity()); | ||
} | ||
|
||
public void shutdown() { | ||
if (shutdown.compareAndSet(false, true)) { | ||
if (requestQueueTask != null) { | ||
// Wakes up the queue in processRequestQueue | ||
requestQueue.offer(NoopTask); | ||
} | ||
|
||
if (cancellableCleanupTask.get() != null) { | ||
logger.debug(() -> "Stopping clean up thread"); | ||
cancellableCleanupTask.get().cancel(); | ||
|
@@ -159,7 +191,7 @@ public boolean isTerminated() { | |
} | ||
|
||
public int queueSize() { | ||
return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); | ||
return requestQueue.size() + rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); | ||
} | ||
|
||
/** | ||
|
@@ -174,9 +206,9 @@ public void start() { | |
started.set(true); | ||
|
||
startCleanupTask(); | ||
startRequestQueueTask(); | ||
signalStartInitiated(); | ||
|
||
handleTasks(); | ||
startHandlingRateLimitedTasks(); | ||
} catch (Exception e) { | ||
logger.warn("Failed to start request executor", e); | ||
cleanup(); | ||
|
||
|
@@ -194,6 +226,11 @@ private void startCleanupTask() { | |
cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL)); | ||
} | ||
|
||
private void startRequestQueueTask() { | ||
assert requestQueueTask == null : "The request queue can only be started once"; | ||
requestQueueTask = threadPool.executor(UTILITY_THREAD_POOL_NAME).submit(this::processRequestQueue); | ||
} | ||
|
||
private Scheduler.Cancellable startCleanupThread(TimeValue interval) { | ||
logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval)); | ||
|
||
|
@@ -221,7 +258,86 @@ private void scheduleNextHandleTasks(TimeValue timeToWait) { | |
return; | ||
} | ||
|
||
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); | ||
threadPool.schedule(this::startHandlingRateLimitedTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); | ||
} | ||
|
||
private void processRequestQueue() { | ||
try { | ||
while (isShutdown() == false) { | ||
var task = requestQueue.take(); | ||
|
||
if (task == NoopTask) { | ||
if (isShutdown()) { | ||
logger.debug("Shutdown requested, exiting request queue processing"); | ||
break; | ||
} | ||
|
||
// Skip processing NoopTask | ||
continue; | ||
} | ||
|
||
if (isShutdown()) { | ||
logger.debug("Shutdown requested while handling request tasks, cleaning up"); | ||
[Review comment thread]
Reviewer: If we're shutting down we need to reject the task we pulled off the queue. Here's an example of doing that: https://github.com/elastic/elasticsearch/blob/8.13/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java#L192
Author: Adjusted with "Reject request on shutdown". |
||
rejectNonRateLimitedRequest(task); | ||
break; | ||
} | ||
|
||
executeTaskImmediately(task); | ||
} | ||
} catch (InterruptedException e) { | ||
Thread.currentThread().interrupt(); | ||
logger.debug("Inference request queue interrupted, exiting"); | ||
} catch (Exception e) { | ||
logger.error("Unexpected error processing request queue, terminating", e); | ||
} finally { | ||
cleanup(); | ||
|
||
} | ||
} | ||
|
||
private void executeTaskImmediately(RejectableTask task) { | ||
try { | ||
task.getRequestManager() | ||
.execute(task.getInferenceInputs(), requestSender, task.getRequestCompletedFunction(), task.getListener()); | ||
} catch (Exception e) { | ||
logger.warn( | ||
format("Failed to execute fast-path request for inference id [%s]", task.getRequestManager().inferenceEntityId()), | ||
e | ||
); | ||
|
||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()), | ||
false | ||
) | ||
); | ||
} | ||
} | ||
|
||
// visible for testing | ||
void submitTaskToRateLimitedExecutionPath(RequestTask task) { | ||
var requestManager = task.getRequestManager(); | ||
var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { | ||
var endpointHandler = new RateLimitingEndpointHandler( | ||
Integer.toString(requestManager.rateLimitGrouping().hashCode()), | ||
queueCreator, | ||
settings, | ||
requestSender, | ||
clock, | ||
requestManager.rateLimitSettings(), | ||
this::isShutdown, | ||
rateLimiterCreator, | ||
rateLimitDivisor.get() | ||
); | ||
|
||
endpointHandler.init(); | ||
return endpointHandler; | ||
}); | ||
|
||
endpoint.enqueue(task); | ||
} | ||
|
||
private static boolean rateLimitingEnabled(RateLimitSettings rateLimitSettings) { | ||
return rateLimitSettings != null && rateLimitSettings.isEnabled(); | ||
} | ||
|
||
private void cleanup() { | ||
|
@@ -234,12 +350,12 @@ private void cleanup() { | |
} | ||
} | ||
|
||
private void handleTasks() { | ||
private void startHandlingRateLimitedTasks() { | ||
try { | ||
TimeValue timeToWait; | ||
do { | ||
if (shutdown.get()) { | ||
logger.debug("Shutdown requested while handling tasks, cleaning up"); | ||
if (isShutdown()) { | ||
logger.debug("Shutdown requested while handling rate limited tasks, cleaning up"); | ||
cleanup(); | ||
return; | ||
} | ||
|
@@ -253,17 +369,44 @@ private void handleTasks() { | |
|
||
scheduleNextHandleTasks(timeToWait); | ||
} catch (Exception e) { | ||
logger.warn("Encountered an error while handling tasks", e); | ||
logger.warn("Encountered an error while handling rate limited tasks", e); | ||
cleanup(); | ||
} | ||
} | ||
|
||
private void notifyRequestsOfShutdown() { | ||
assert isShutdown() : "Requests should only be notified if the executor is shutting down"; | ||
|
||
// Reject rate-limited requests | ||
for (var endpoint : rateLimitGroupings.values()) { | ||
endpoint.notifyRequestsOfShutdown(); | ||
} | ||
|
||
// Reject non-rate-limited requests | ||
List<RejectableTask> requests = new ArrayList<>(); | ||
requestQueue.drainTo(requests); | ||
|
||
for (var request : requests) { | ||
rejectNonRateLimitedRequest(request); | ||
} | ||
} | ||
|
||
private void rejectNonRateLimitedRequest(RejectableTask task) { | ||
var inferenceEntityId = task.getRequestManager().inferenceEntityId(); | ||
|
||
rejectRequest( | ||
task, | ||
format("Failed to send request for inference id [%s] has shutdown prior to executing request", inferenceEntityId), | ||
|
||
format("Failed to notify request for inference id [%s] of rejection after executor service shutdown", inferenceEntityId) | ||
); | ||
} | ||
|
||
private static void rejectRequest(RejectableTask task, String rejectionMessage, String rejectionFailedMessage) { | ||
try { | ||
task.onRejection(new EsRejectedExecutionException(rejectionMessage, true)); | ||
} catch (Exception e) { | ||
logger.warn(rejectionFailedMessage); | ||
} | ||
} | ||
|
||
// default for testing | ||
|
@@ -308,26 +451,33 @@ public void execute( | |
ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) | ||
); | ||
|
||
var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { | ||
var endpointHandler = new RateLimitingEndpointHandler( | ||
Integer.toString(requestManager.rateLimitGrouping().hashCode()), | ||
queueCreator, | ||
settings, | ||
requestSender, | ||
clock, | ||
requestManager.rateLimitSettings(), | ||
this::isShutdown, | ||
rateLimiterCreator, | ||
rateLimitDivisor.get() | ||
if (isShutdown()) { | ||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format( | ||
"Failed to enqueue request task for inference id [%s] because the request executor service has been shutdown", | ||
requestManager.inferenceEntityId() | ||
), | ||
true | ||
) | ||
); | ||
return; | ||
} | ||
|
||
// TODO: add or create/compute if absent set for new map (service/task-type-key -> rate limit endpoint handler) | ||
|
||
endpointHandler.init(); | ||
return endpointHandler; | ||
}); | ||
if (rateLimitingEnabled(requestManager.rateLimitSettings())) { | ||
submitTaskToRateLimitedExecutionPath(task); | ||
} else { | ||
boolean taskAccepted = requestQueue.offer(task); | ||
|
||
endpoint.enqueue(task); | ||
if (taskAccepted == false) { | ||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format("Failed to enqueue request task for inference id [%s]", requestManager.inferenceEntityId()), | ||
false | ||
) | ||
); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -423,7 +573,7 @@ public synchronized TimeValue executeEnqueuedTask() { | |
} | ||
|
||
private TimeValue executeEnqueuedTaskInternal() { | ||
if (rateLimitSettings.isEnabled()) { | ||
if (rateLimitingEnabled(rateLimitSettings)) { | ||
var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); | ||
if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { | ||
return timeBeforeAvailableToken; | ||
|
@@ -514,27 +664,18 @@ public synchronized void notifyRequestsOfShutdown() { | |
|
||
private void rejectTasks(List<RejectableTask> tasks) { | ||
for (var task : tasks) { | ||
rejectTaskForShutdown(task); | ||
} | ||
} | ||
var inferenceEntityId = task.getRequestManager().inferenceEntityId(); | ||
|
||
private void rejectTaskForShutdown(RejectableTask task) { | ||
try { | ||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format( | ||
"Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request", | ||
id, | ||
task.getRequestManager().inferenceEntityId() | ||
), | ||
true | ||
) | ||
); | ||
} catch (Exception e) { | ||
logger.warn( | ||
rejectRequest( | ||
task, | ||
format( | ||
"Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request", | ||
id, | ||
|
||
inferenceEntityId | ||
), | ||
format( | ||
"Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown", | ||
task.getRequestManager().inferenceEntityId(), | ||
inferenceEntityId, | ||
id | ||
) | ||
); | ||
|
@@ -549,4 +690,37 @@ public void close() { | |
requestExecutorServiceSettings.deregisterQueueCapacityCallback(id); | ||
} | ||
} | ||
|
||
private static final RejectableTask NoopTask = new RejectableTask() { | ||
@Override | ||
public void onRejection(Exception e) { | ||
throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); | ||
} | ||
|
||
@Override | ||
public RequestManager getRequestManager() { | ||
throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); | ||
} | ||
|
||
@Override | ||
public InferenceInputs getInferenceInputs() { | ||
throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); | ||
} | ||
|
||
@Override | ||
public ActionListener<InferenceServiceResults> getListener() { | ||
throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); | ||
} | ||
|
||
@Override | ||
public boolean hasCompleted() { | ||
throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); | ||
} | ||
|
||
@Override | ||
public Supplier<Boolean> getRequestCompletedFunction() { | ||
throw new UnsupportedOperationException("NoopTask is a pure marker class for signals in the request queue"); | ||
} | ||
}; | ||
|
||
} |
[Review comment thread]
Reviewer: How about we move this above the `signalStartInitiated()` call? That way all the threading stuff is done prior to the start signal.
Author: Adjusted with "Move startRequestQueueTask before start signal".