[ML] Adding asynchronous start up logic for the inference API internals #135462
Changes from 5 commits
@@ -9,6 +9,7 @@
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.core.IOUtils;

@@ -73,10 +74,11 @@ public void infer(
         @Nullable TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
-        init();
-        var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
-        doInfer(model, inferenceInput, taskSettings, timeout, listener);
+        SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((inferListener) -> {
+            var resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
+            var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
+            doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener);
+        }).addListener(listener);
     }

     private static InferenceInputs createInput(

@@ -121,8 +123,9 @@ public void unifiedCompletionInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        init();
-        doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener);
+        SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((completionInferListener) -> {
+            doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, completionInferListener);
+        }).addListener(listener);
     }

     @Override

@@ -135,16 +138,16 @@ public void chunkedInfer(
         TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
-        init();
-
-        ValidationException validationException = new ValidationException();
-        validateInputType(inputType, model, validationException);
-        if (validationException.validationErrors().isEmpty() == false) {
-            throw validationException;
-        }
+        SubscribableListener.newForked(this::init).<List<ChunkedInference>>andThen((chunkedInferListener) -> {
+            ValidationException validationException = new ValidationException();
+            validateInputType(inputType, model, validationException);
+            if (validationException.validationErrors().isEmpty() == false) {
+                throw validationException;
+            }

-        // a non-null query is not supported and is dropped by all providers
-        doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
+            // a non-null query is not supported and is dropped by all providers
+            doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
+        }).addListener(listener);
     }

     protected abstract void doInfer(

@@ -176,8 +179,9 @@ protected abstract void doChunkedInfer(
     );

     public void start(Model model, ActionListener<Boolean> listener) {
-        init();
-        doStart(model, listener);
+        SubscribableListener.newForked(this::init)
+            .<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))
Review comment: What is the purpose of calling

Reply: The idea is that it can be overridden by child classes. In reality I don't think any actually override it yet. The Elasticsearch integration does use it but that doesn't extend from

Reply: Gotcha, thanks for the explanation.
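For readers following the thread: the overridable hook being discussed appears to be doStart, whose default implementation (shown further down in this hunk) simply reports success. Below is a minimal, self-contained sketch of the chained start pattern; the class is a hypothetical stand-in, not the real service class, and it drops the Model argument that the diff threads through to doStart.

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;

// Hypothetical stand-in that only models the start/init/doStart chaining,
// so the override point is easy to see.
class StartChainSketch {

    public final void start(ActionListener<Boolean> listener) {
        // Fork the asynchronous init step, then run the overridable doStart hook,
        // and deliver the Boolean result (or any failure) to the caller's listener.
        SubscribableListener.<Void>newForked(this::init)
            .<Boolean>andThen(this::doStart)
            .addListener(listener);
    }

    // Stand-in for sender.startAsynchronously(listener).
    private void init(ActionListener<Void> listener) {
        listener.onResponse(null);
    }

    // Default behaviour mirrors the diff: report success immediately.
    // A child class with real startup work would override this method.
    protected void doStart(ActionListener<Boolean> listener) {
        listener.onResponse(true);
    }
}
```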
+            .addListener(listener);
     }

     @Override

@@ -189,8 +193,8 @@ protected void doStart(Model model, ActionListener<Boolean> listener) {
         listener.onResponse(true);
     }

-    private void init() {
-        sender.start();
+    private void init(ActionListener<Void> listener) {
+        sender.startAsynchronously(listener);
     }

     @Override
@@ -75,7 +75,7 @@ public Factory(

         public Sender createSender() {
             // ensure this is started
-            bedrockRequestSender.start();
+            bedrockRequestSender.startSynchronously();
             return bedrockRequestSender;
         }
     }

@@ -97,8 +97,17 @@ protected AmazonBedrockRequestSender(
         this.startCompleted = Objects.requireNonNull(startCompleted);
     }

+    /**
+     * TODO implement this functionality to ensure that we don't block node bootups
+     *
+     * See: https://github.com/elastic/ml-team/issues/1701
+     */
     @Override
-    public void start() {
+    public void startAsynchronously(ActionListener<Void> listener) {
+        throw new UnsupportedOperationException("not implemented");
Review comment: Would it be worth wrapping this

Reply: Hmm I think in that situation we should still throw. It would be a bug if we're ever calling that method for
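For context, the alternative being floated (reporting the failure through the listener rather than throwing) would look roughly like the fragment below. This is a sketch only, not what the PR does; the PR keeps the throw so that reaching this unimplemented path shows up as a bug.

```java
// Sketch of the alternative: notify the listener of the unsupported call instead of throwing.
@Override
public void startAsynchronously(ActionListener<Void> listener) {
    listener.onFailure(new UnsupportedOperationException("not implemented"));
}
```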
+    }
+
+    @Override
+    public void startSynchronously() {
         if (started.compareAndSet(false, true)) {
             // The manager must be started before the executor service. That way we guarantee that the http client
             // is ready prior to the service attempting to use the http client to send a request
@@ -12,6 +12,7 @@
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;

@@ -82,37 +83,33 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
                 return;
             }

-            // ensure that the sender is initialized
-            sender.start();
-
-            ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
-                if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
-                    logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
-                    listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
-                } else {
-                    var errorMessage = Strings.format(
-                        "%s Received an invalid response type from the Elastic Inference Service: %s",
-                        FAILED_TO_RETRIEVE_MESSAGE,
-                        results.getClass().getSimpleName()
-                    );
-
-                    logger.warn(errorMessage);
-                    listener.onFailure(new ElasticsearchException(errorMessage));
-                }
-                requestCompleteLatch.countDown();
-            }, e -> {
+            var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> {
                 // unwrap because it's likely a retry exception
                 var exception = ExceptionsHelper.unwrapCause(e);

                 logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
-                listener.onFailure(e);
-                requestCompleteLatch.countDown();
+                authModelListener.onFailure(e);
             });

-            var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
-            var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
+            SubscribableListener.newForked(sender::startAsynchronously).<InferenceServiceResults>andThen((authListener) -> {
Comment: Now we're doing an async start and then once that completes we do the rest of the functionality as normal.
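A condensed, self-contained sketch of that flow with hypothetical stand-in names (only the SubscribableListener and ActionListener calls mirror the diff): fork the asynchronous sender start, send the request once the start completes, map the raw result to the model type, and notify the caller while counting the latch down on success and failure alike.

```java
import java.util.concurrent.CountDownLatch;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;

// Hypothetical illustration of the chain used above: async start -> send -> map -> notify.
final class AuthFlowSketch {

    // Stand-in for the sender's asynchronous start.
    private static void startAsynchronously(ActionListener<Void> listener) {
        listener.onResponse(null);
    }

    // Stand-in for sender.sendWithoutQueuing(...), producing a raw response object.
    private static void sendAuthRequest(ActionListener<Object> listener) {
        listener.onResponse("raw auth response");
    }

    static void getAuthorization(ActionListener<String> listener, CountDownLatch requestCompleteLatch) {
        SubscribableListener.<Void>newForked(AuthFlowSketch::startAsynchronously)
            // Once the sender has started, issue the authorization request.
            .<Object>andThen(AuthFlowSketch::sendAuthRequest)
            // Convert the raw response to the model type, or throw to fail the chain.
            .andThenApply(rawResult -> {
                if (rawResult instanceof String response) {
                    return "auth model from: " + response;
                }
                throw new ElasticsearchException("unexpected response type");
            })
            // Notify the caller and count the latch down on success and on failure alike.
            .addListener(ActionListener.runAfter(listener, requestCompleteLatch::countDown));
    }
}
```

The runAfter wrapper is what lets the latch count down regardless of which step in the chain fails.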
+                var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
+                var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
+                sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener);
+            }).andThenApply(authResult -> {
+                if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
+                    logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
+                    return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity);
+                }
+
+                var errorMessage = Strings.format(
+                    "%s Received an invalid response type from the Elastic Inference Service: %s",
+                    FAILED_TO_RETRIEVE_MESSAGE,
+                    authResult.getClass().getSimpleName()
+                );
+
-            sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
+                logger.warn(errorMessage);
+                throw new ElasticsearchException(errorMessage);
+            }).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown));
         } catch (Exception e) {
             logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e));
             requestCompleteLatch.countDown();
@@ -105,17 +105,17 @@ public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

         try (var sender = createSender(senderFactory)) {
-            sender.start();
-            sender.start();
-            sender.start();
+            sender.startSynchronously();
+            sender.startSynchronously();
+            sender.startSynchronously();
Review comment on lines +117 to +119: It would be good to add some tests for the
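One possible shape for such a test, sketched against the helpers already used in this file (the factory setup, createSender, PlainActionFuture); the test name and the 30-second wait are assumptions, not part of the PR.

```java
// Hypothetical test sketch: verify that startAsynchronously completes its listener.
public void testCreateSender_StartAsynchronously_CompletesListener() throws Exception {
    var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

    try (var sender = createSender(senderFactory)) {
        PlainActionFuture<Void> startFuture = new PlainActionFuture<>();
        sender.startAsynchronously(startFuture);

        // Blocks until the asynchronous startup reports completion, or fails the test on error/timeout.
        startFuture.actionGet(TimeValue.timeValueSeconds(30));
    }
}
```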
         }
     }

     public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

         try (var sender = createSender(senderFactory)) {
-            sender.start();
+            sender.startSynchronously();

             String responseJson = """
                 {

@@ -167,7 +167,7 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exception {
         var senderFactory = createSenderFactory(clientManager, threadRef);

         try (var sender = createSender(senderFactory)) {
-            sender.start();
+            sender.startSynchronously();

             String responseJson = """
                 {

@@ -240,7 +240,7 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception {

         try (var sender = senderFactory.createSender()) {
             assertThat(sender, instanceOf(HttpRequestSender.class));
-            sender.start();
+            sender.startSynchronously();

             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener);

@@ -263,7 +263,7 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws Exception {
         );

         try (var sender = senderFactory.createSender()) {
-            sender.start();
+            sender.startSynchronously();

             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener);

@@ -286,7 +286,7 @@ public void testSendWithoutQueuingWithTimeout_Throws_WhenATimeoutOccurs() throws Exception {
         );

         try (var sender = senderFactory.createSender()) {
-            sender.start();
+            sender.startSynchronously();

             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             sender.sendWithoutQueuing(
Review comment: Will this cause any exception thrown in startInternal() to be ignored when doing a synchronous start? Also, do we need to make sure that we always call waitForStartToComplete() before returning from this method? If someone calls startAsynchronously() and then another thread immediately calls startSynchronously(), the second call will return immediately (because we already set started to true) but the sender won't actually have started yet.

Reply: Good point, I'll make those changes.
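A rough sketch of the guard the comment is asking for, assuming the started flag, startInternal(), and waitForStartToComplete() members it mentions; the actual follow-up change in the PR may differ.

```java
// Sketch only: make the synchronous start safe against a concurrent asynchronous start.
@Override
public void startSynchronously() {
    if (started.compareAndSet(false, true)) {
        // This thread won the race, so it performs the actual startup work.
        // Any exception thrown by startInternal() propagates directly to the caller here.
        startInternal();
    }
    // Whether or not this thread initiated the start (e.g. startAsynchronously() was
    // called first by another thread), block until startup has completed so the sender
    // is guaranteed to be usable when this method returns.
    waitForStartToComplete();
}
```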