-
Notifications
You must be signed in to change notification settings - Fork 25.5k
[Inference API] Remove worst-case additional 50ms latency for non-rate limited requests #136167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
f1e7419
a326e31
191029f
edaa0f0
57c5605
f98b82e
f17ec00
90ee1a1
a9e7610
ec513be
174526c
f506cb3
ae349fd
540f49d
0dca88a
0590561
2e65475
4fb2372
2930151
b2fd85f
91f387a
90e672b
1cf24dc
a92a7c0
8e52c22
a868152
dd53fcb
c575eba
69db0e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pr: 136167 | ||
summary: "[Inference API] Remove worst-case additional 50ms latency for non-rate limited\ | ||
\ requests" | ||
area: Machine Learning | ||
type: bug | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.action.support.ContextPreservingActionListener; | ||
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; | ||
import org.elasticsearch.common.util.concurrent.FutureUtils; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.core.Strings; | ||
import org.elasticsearch.core.TimeValue; | ||
|
@@ -33,6 +34,7 @@ | |
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.ConcurrentMap; | ||
import java.util.concurrent.CountDownLatch; | ||
import java.util.concurrent.Future; | ||
import java.util.concurrent.LinkedBlockingQueue; | ||
import java.util.concurrent.TimeUnit; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
|
@@ -53,6 +55,29 @@ | |
* {@link org.apache.http.client.methods.HttpUriRequest} to set a timeout for how long this executor will wait | ||
* attempting to execute a task (aka waiting for the connection manager to lease a connection). See | ||
* {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. | ||
* | ||
* The request flow looks as follows: | ||
* | ||
* -------------> Execute request immediately. | ||
* | | ||
* | | ||
* request NOT supporting | ||
* rate limiting | ||
* | | ||
* | | ||
* Request ---> [-Request Queue-] | ||
* | | ||
* | | ||
* request supporting | ||
* rate limiting | ||
* | | ||
* | | ||
* ------------> {Rate Limit Group 1 -> Queue 1, ..., Rate Limit Group N -> Queue N} | ||
* | ||
* Explanation: Submit request to the queue for the specific rate limiting group. | ||
* The rate limiting groups are polled at the same specified interval, | ||
* which in the worst cases introduces an additional latency of | ||
* {@link RequestExecutorServiceSettings#getTaskPollFrequency()}. | ||
*/ | ||
public class RequestExecutorService implements RequestExecutor { | ||
|
||
|
@@ -109,6 +134,8 @@ interface RateLimiterCreator { | |
private final RateLimiterCreator rateLimiterCreator; | ||
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>(); | ||
private final AtomicBoolean started = new AtomicBoolean(false); | ||
private final AdjustableCapacityBlockingQueue<RejectableTask> requestQueue; | ||
private volatile Future<?> requestQueueTask; | ||
|
||
public RequestExecutorService( | ||
ThreadPool threadPool, | ||
|
@@ -135,10 +162,16 @@ public RequestExecutorService( | |
this.settings = Objects.requireNonNull(settings); | ||
this.clock = Objects.requireNonNull(clock); | ||
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); | ||
this.requestQueue = new AdjustableCapacityBlockingQueue<>(queueCreator, settings.getQueueCapacity()); | ||
} | ||
|
||
public void shutdown() { | ||
if (shutdown.compareAndSet(false, true)) { | ||
if (requestQueueTask != null) { | ||
boolean cancelled = FutureUtils.cancel(requestQueueTask); | ||
logger.debug(() -> format("Request queue cancellation successful: %s", cancelled)); | ||
} | ||
|
||
if (cancellableCleanupTask.get() != null) { | ||
logger.debug(() -> "Stopping clean up thread"); | ||
cancellableCleanupTask.get().cancel(); | ||
|
@@ -159,7 +192,7 @@ public boolean isTerminated() { | |
} | ||
|
||
public int queueSize() { | ||
return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); | ||
return requestQueue.size() + rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); | ||
} | ||
|
||
/** | ||
|
@@ -175,8 +208,8 @@ public void start() { | |
|
||
startCleanupTask(); | ||
signalStartInitiated(); | ||
|
||
handleTasks(); | ||
startRequestQueueTask(); | ||
jonathan-buttner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
startHandlingRateLimitedTasks(); | ||
} catch (Exception e) { | ||
logger.warn("Failed to start request executor", e); | ||
cleanup(); | ||
|
||
|
@@ -194,6 +227,11 @@ private void startCleanupTask() { | |
cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL)); | ||
} | ||
|
||
private void startRequestQueueTask() { | ||
assert requestQueueTask == null : "The request queue can only be started once"; | ||
requestQueueTask = threadPool.executor(UTILITY_THREAD_POOL_NAME).submit(this::processRequestQueue); | ||
} | ||
|
||
private Scheduler.Cancellable startCleanupThread(TimeValue interval) { | ||
logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval)); | ||
|
||
|
@@ -221,7 +259,83 @@ private void scheduleNextHandleTasks(TimeValue timeToWait) { | |
return; | ||
} | ||
|
||
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); | ||
threadPool.schedule(this::startHandlingRateLimitedTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); | ||
} | ||
|
||
private void processRequestQueue() { | ||
try { | ||
while (isShutdown() == false) { | ||
// Blocks the request queue thread until a new request comes in | ||
var task = (RequestTask) requestQueue.take(); | ||
|
||
|
||
if (isShutdown()) { | ||
logger.debug("Shutdown requested while handling request tasks, cleaning up"); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we're shutting down we need to reject the task we pulled off the queue. Here's an example of doing that: https://github.com/elastic/elasticsearch/blob/8.13/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java#L192 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adjusted with: Reject request on shutdown |
||
cleanup(); | ||
|
||
return; | ||
} else { | ||
var requestManager = task.getRequestManager(); | ||
|
||
if (rateLimitingEnabled(requestManager)) { | ||
submitTaskToRateLimitedExecutionPath(task); | ||
} else { | ||
executeTaskImmediately(task); | ||
} | ||
} | ||
} | ||
} catch (InterruptedException e) { | ||
// Restore interrupt to propagate to the calling thread | ||
Thread.currentThread().interrupt(); | ||
logger.debug("Inference request queue interrupted, exiting"); | ||
} catch (Exception e) { | ||
logger.warn("Error processing task in inference request queue", e); | ||
cleanup(); | ||
|
||
} | ||
} | ||
|
||
private void executeTaskImmediately(RequestTask task) { | ||
try { | ||
task.getRequestManager() | ||
.execute(task.getInferenceInputs(), requestSender, task.getRequestCompletedFunction(), task.getListener()); | ||
} catch (Exception e) { | ||
logger.warn( | ||
format("Failed to execute fast-path request for inference id [%s]", task.getRequestManager().inferenceEntityId()), | ||
e | ||
); | ||
|
||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()), | ||
false | ||
) | ||
); | ||
} | ||
} | ||
|
||
// visible for testing | ||
void submitTaskToRateLimitedExecutionPath(RequestTask task) { | ||
var requestManager = task.getRequestManager(); | ||
var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { | ||
var endpointHandler = new RateLimitingEndpointHandler( | ||
Integer.toString(requestManager.rateLimitGrouping().hashCode()), | ||
queueCreator, | ||
settings, | ||
requestSender, | ||
clock, | ||
requestManager.rateLimitSettings(), | ||
this::isShutdown, | ||
rateLimiterCreator, | ||
rateLimitDivisor.get() | ||
); | ||
|
||
endpointHandler.init(); | ||
return endpointHandler; | ||
}); | ||
|
||
endpoint.enqueue(task); | ||
} | ||
|
||
private boolean rateLimitingEnabled(RequestManager requestManager) { | ||
|
||
return requestManager.rateLimitSettings() != null && requestManager.rateLimitSettings().isEnabled(); | ||
} | ||
|
||
private void cleanup() { | ||
|
@@ -234,12 +348,12 @@ private void cleanup() { | |
} | ||
} | ||
|
||
private void handleTasks() { | ||
private void startHandlingRateLimitedTasks() { | ||
try { | ||
TimeValue timeToWait; | ||
do { | ||
if (shutdown.get()) { | ||
logger.debug("Shutdown requested while handling tasks, cleaning up"); | ||
if (isShutdown()) { | ||
logger.debug("Shutdown requested while handling rate limited tasks, cleaning up"); | ||
cleanup(); | ||
return; | ||
} | ||
|
@@ -253,17 +367,47 @@ private void handleTasks() { | |
|
||
scheduleNextHandleTasks(timeToWait); | ||
} catch (Exception e) { | ||
logger.warn("Encountered an error while handling tasks", e); | ||
logger.warn("Encountered an error while handling rate limited tasks", e); | ||
cleanup(); | ||
} | ||
} | ||
|
||
private void notifyRequestsOfShutdown() { | ||
assert isShutdown() : "Requests should only be notified if the executor is shutting down"; | ||
|
||
// Reject rate-limited requests | ||
for (var endpoint : rateLimitGroupings.values()) { | ||
endpoint.notifyRequestsOfShutdown(); | ||
} | ||
|
||
// Reject non-rate-limited requests | ||
List<RejectableTask> requests = new ArrayList<>(); | ||
requestQueue.drainTo(requests); | ||
|
||
for (var request : requests) { | ||
rejectRequest(request); | ||
} | ||
} | ||
|
||
private void rejectRequest(RejectableTask task) { | ||
try { | ||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format( | ||
"Failed to send request for inference id [%s] has shutdown prior to executing request", | ||
task.getRequestManager().inferenceEntityId() | ||
), | ||
true | ||
) | ||
); | ||
} catch (Exception e) { | ||
logger.warn( | ||
format( | ||
"Failed to notify request for inference id [%s] of rejection after executor service shutdown", | ||
task.getRequestManager().inferenceEntityId() | ||
) | ||
); | ||
} | ||
|
||
} | ||
|
||
// default for testing | ||
|
@@ -308,26 +452,29 @@ public void execute( | |
ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) | ||
); | ||
|
||
var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { | ||
var endpointHandler = new RateLimitingEndpointHandler( | ||
Integer.toString(requestManager.rateLimitGrouping().hashCode()), | ||
queueCreator, | ||
settings, | ||
requestSender, | ||
clock, | ||
requestManager.rateLimitSettings(), | ||
this::isShutdown, | ||
rateLimiterCreator, | ||
rateLimitDivisor.get() | ||
if (isShutdown()) { | ||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format( | ||
"Failed to enqueue request task for inference id [%s] because the request executor service has been shutdown", | ||
requestManager.inferenceEntityId() | ||
), | ||
true | ||
) | ||
); | ||
return; | ||
} | ||
|
||
// TODO: add or create/compute if absent set for new map (service/task-type-key -> rate limit endpoint handler) | ||
boolean taskAccepted = requestQueue.offer(task); | ||
|
||
|
||
endpointHandler.init(); | ||
return endpointHandler; | ||
}); | ||
|
||
endpoint.enqueue(task); | ||
if (taskAccepted == false) { | ||
|
||
task.onRejection( | ||
new EsRejectedExecutionException( | ||
format("Failed to enqueue request task for inference id [%s]", requestManager.inferenceEntityId()), | ||
false | ||
) | ||
); | ||
} | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I remember correctly, I think it's up to our implementation to check if it is canceled. So I think we'll get stuck in the
queue.take()
🤔 It doesn't seem like
FutureUtils.cancel()
will do an interrupt. This is how we've handled that in the past: https://github.com/elastic/elasticsearch/blob/8.13/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java#L253
The
shutdown()
method puts a noop task on the queue to ensure that it wakes up. Can you double check the tests and make sure we're covering this case (we call shutdown and then await termination)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adjusted with: Add NoopTask to wake up queue on shutdown
AFAIU we always check that when calling submitShutdownRequest, right?