-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Adding asynchronous start up logic for the inference API internals #135462
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5833e3e
c77289e
ac3f29c
e3c9103
de0251e
a901dc6
a5d9a4e
67cd5c4
7b802ce
1c6d667
0ea22a0
ff26624
3779c49
3277d23
ef179b3
8d40454
2110e95
0eba997
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.action.support.SubscribableListener; | ||
| import org.elasticsearch.cluster.service.ClusterService; | ||
| import org.elasticsearch.common.ValidationException; | ||
| import org.elasticsearch.core.IOUtils; | ||
|
|
@@ -73,10 +74,11 @@ public void infer( | |
| @Nullable TimeValue timeout, | ||
| ActionListener<InferenceServiceResults> listener | ||
| ) { | ||
| timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); | ||
| init(); | ||
| var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); | ||
| doInfer(model, inferenceInput, taskSettings, timeout, listener); | ||
| SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((inferListener) -> { | ||
| var resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); | ||
| var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); | ||
| doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener); | ||
| }).addListener(listener); | ||
| } | ||
|
|
||
| private static InferenceInputs createInput( | ||
|
|
@@ -121,8 +123,9 @@ public void unifiedCompletionInfer( | |
| TimeValue timeout, | ||
| ActionListener<InferenceServiceResults> listener | ||
| ) { | ||
| init(); | ||
| doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener); | ||
| SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((completionInferListener) -> { | ||
| doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, completionInferListener); | ||
| }).addListener(listener); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -135,16 +138,16 @@ public void chunkedInfer( | |
| TimeValue timeout, | ||
| ActionListener<List<ChunkedInference>> listener | ||
| ) { | ||
| init(); | ||
|
|
||
| ValidationException validationException = new ValidationException(); | ||
| validateInputType(inputType, model, validationException); | ||
| if (validationException.validationErrors().isEmpty() == false) { | ||
| throw validationException; | ||
| } | ||
| SubscribableListener.newForked(this::init).<List<ChunkedInference>>andThen((chunkedInferListener) -> { | ||
| ValidationException validationException = new ValidationException(); | ||
| validateInputType(inputType, model, validationException); | ||
| if (validationException.validationErrors().isEmpty() == false) { | ||
| throw validationException; | ||
| } | ||
|
|
||
| // a non-null query is not supported and is dropped by all providers | ||
| doChunkedInfer(model, input, taskSettings, inputType, timeout, listener); | ||
| // a non-null query is not supported and is dropped by all providers | ||
| doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener); | ||
| }).addListener(listener); | ||
| } | ||
|
|
||
| protected abstract void doInfer( | ||
|
|
@@ -176,8 +179,9 @@ protected abstract void doChunkedInfer( | |
| ); | ||
|
|
||
| public void start(Model model, ActionListener<Boolean> listener) { | ||
| init(); | ||
| doStart(model, listener); | ||
| SubscribableListener.newForked(this::init) | ||
| .<Boolean>andThen((doStartListener) -> doStart(model, doStartListener)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the purpose of calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea is that it can be overridden by child classes. In reality I don't think any actually override it yet. The Elasticsearch integration does use it but that doesn't extend from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gotcha, thanks for the explanation |
||
| .addListener(listener); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -189,8 +193,8 @@ protected void doStart(Model model, ActionListener<Boolean> listener) { | |
| listener.onResponse(true); | ||
| } | ||
|
|
||
| private void init() { | ||
| sender.start(); | ||
| private void init(ActionListener<Void> listener) { | ||
| sender.startAsynchronously(listener); | ||
| } | ||
|
|
||
| @Override | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,7 +75,7 @@ public Factory( | |
|
|
||
| public Sender createSender() { | ||
| // ensure this is started | ||
| bedrockRequestSender.start(); | ||
| bedrockRequestSender.startSynchronously(); | ||
| return bedrockRequestSender; | ||
| } | ||
| } | ||
|
|
@@ -98,7 +98,13 @@ protected AmazonBedrockRequestSender( | |
| } | ||
|
|
||
| @Override | ||
| public void start() { | ||
| public void startAsynchronously(ActionListener<Void> listener) { | ||
|
|
||
| throw new UnsupportedOperationException("not implemented"); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be worth wrapping this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm I think in that situation we should still throw. It would be a bug if we're ever calling that method for |
||
| } | ||
|
|
||
| @Override | ||
| public void startSynchronously() { | ||
| if (started.compareAndSet(false, true)) { | ||
| // The manager must be started before the executor service. That way we guarantee that the http client | ||
| // is ready prior to the service attempting to use the http client to send a request | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| import org.elasticsearch.ElasticsearchException; | ||
| import org.elasticsearch.ExceptionsHelper; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.action.support.SubscribableListener; | ||
| import org.elasticsearch.common.Strings; | ||
| import org.elasticsearch.core.Nullable; | ||
| import org.elasticsearch.core.TimeValue; | ||
|
|
@@ -82,37 +83,33 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization | |
| return; | ||
| } | ||
|
|
||
| // ensure that the sender is initialized | ||
| sender.start(); | ||
|
|
||
| ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> { | ||
| if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { | ||
| logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); | ||
| listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity)); | ||
| } else { | ||
| var errorMessage = Strings.format( | ||
| "%s Received an invalid response type from the Elastic Inference Service: %s", | ||
| FAILED_TO_RETRIEVE_MESSAGE, | ||
| results.getClass().getSimpleName() | ||
| ); | ||
|
|
||
| logger.warn(errorMessage); | ||
| listener.onFailure(new ElasticsearchException(errorMessage)); | ||
| } | ||
| requestCompleteLatch.countDown(); | ||
| }, e -> { | ||
| var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> { | ||
| // unwrap because it's likely a retry exception | ||
| var exception = ExceptionsHelper.unwrapCause(e); | ||
|
|
||
| logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception); | ||
| listener.onFailure(e); | ||
| requestCompleteLatch.countDown(); | ||
| authModelListener.onFailure(e); | ||
| }); | ||
|
|
||
| var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); | ||
| var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata); | ||
| SubscribableListener.newForked(sender::startAsynchronously).<InferenceServiceResults>andThen((authListener) -> { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now we're doing an async start and then once that completes we do the rest of the functionality as normal. |
||
| var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); | ||
| var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata); | ||
| sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); | ||
| }).andThenApply(authResult -> { | ||
| if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { | ||
| logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); | ||
| return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity); | ||
| } | ||
|
|
||
| var errorMessage = Strings.format( | ||
| "%s Received an invalid response type from the Elastic Inference Service: %s", | ||
| FAILED_TO_RETRIEVE_MESSAGE, | ||
| authResult.getClass().getSimpleName() | ||
| ); | ||
|
|
||
| sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener); | ||
| logger.warn(errorMessage); | ||
| throw new ElasticsearchException(errorMessage); | ||
| }).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown)); | ||
| } catch (Exception e) { | ||
| logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e)); | ||
| requestCompleteLatch.countDown(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if we need to do something similar for async calls, since if two async calls come in one after the other, the second one will complete immediately even if the first one hasn't finished starting the sender yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, I tried to come up with a solution that would avoid having to spin up a thread to then call
`waitForStartToComplete`, since most of the time it will simply return.