Skip to content

Commit 940e8c8

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Adding asynchronous start up logic for the inference API internals (#135462)
* Refactoring init to be async * Fixing bedrock tests * Fixing more tests * Fixing tests * [CI] Auto commit changes from spotless * Fixing typo * Adding more notes on bedrock * Addressing feedback * Adding exception handling * [CI] Auto commit changes from spotless * Refactoring async start * rename --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 9353842 commit 940e8c8

File tree

51 files changed

+622
-397
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+622
-397
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.http.sender;
99

10+
import org.apache.logging.log4j.LogManager;
1011
import org.apache.logging.log4j.Logger;
1112
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.action.support.ContextPreservingActionListener;
@@ -42,6 +43,7 @@ public class HttpRequestSender implements Sender {
4243
*/
4344
public static class Factory {
4445
private final HttpRequestSender httpRequestSender;
46+
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
4547

4648
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
4749
Objects.requireNonNull(serviceComponents);
@@ -68,7 +70,8 @@ public Factory(ServiceComponents serviceComponents, HttpClientManager httpClient
6870
httpClientManager,
6971
requestSender,
7072
service,
71-
startCompleted
73+
startCompleted,
74+
START_COMPLETED_WAIT_TIME
7275
);
7376
}
7477

@@ -77,45 +80,95 @@ public Sender createSender() {
7780
}
7881
}
7982

80-
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
83+
private static final Logger logger = LogManager.getLogger(HttpRequestSender.class);
8184

8285
private final ThreadPool threadPool;
8386
private final HttpClientManager manager;
84-
private final AtomicBoolean started = new AtomicBoolean(false);
87+
private final AtomicBoolean startInitiated = new AtomicBoolean(false);
88+
private final AtomicBoolean startCompleted = new AtomicBoolean(false);
8589
private final RequestSender requestSender;
8690
private final RequestExecutor service;
87-
private final CountDownLatch startCompleted;
91+
private final CountDownLatch startCompletedLatch;
92+
private final TimeValue startCompletedWaitTime;
8893

89-
private HttpRequestSender(
94+
// Visible for testing
95+
protected HttpRequestSender(
9096
ThreadPool threadPool,
9197
HttpClientManager httpClientManager,
9298
RequestSender requestSender,
9399
RequestExecutor service,
94-
CountDownLatch startCompleted
100+
CountDownLatch startCompletedLatch,
101+
TimeValue startCompletedWaitTime
95102
) {
96103
this.threadPool = Objects.requireNonNull(threadPool);
97104
this.manager = Objects.requireNonNull(httpClientManager);
98105
this.requestSender = Objects.requireNonNull(requestSender);
99106
this.service = Objects.requireNonNull(service);
100-
this.startCompleted = Objects.requireNonNull(startCompleted);
107+
this.startCompletedLatch = Objects.requireNonNull(startCompletedLatch);
108+
this.startCompletedWaitTime = Objects.requireNonNull(startCompletedWaitTime);
101109
}
102110

103111
/**
104-
* Start various internal services. This is required before sending requests.
112+
* Start various internal services asynchronously. This is required before sending requests.
105113
*/
106-
public void start() {
107-
if (started.compareAndSet(false, true)) {
114+
@Override
115+
public void startAsynchronously(ActionListener<Void> listener) {
116+
if (startInitiated.compareAndSet(false, true)) {
117+
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
118+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> startInternal(preservedListener));
119+
} else if (startCompleted.get() == false) {
120+
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
121+
// wait on another thread so we don't potential block a transport thread
122+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> waitForStartToCompleteWithListener(preservedListener));
123+
} else {
124+
listener.onResponse(null);
125+
}
126+
}
127+
128+
private void startInternal(ActionListener<Void> listener) {
129+
try {
108130
// The manager must be started before the executor service. That way we guarantee that the http client
109131
// is ready prior to the service attempting to use the http client to send a request
110132
manager.start();
111133
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(service::start);
112134
waitForStartToComplete();
135+
startCompleted.set(true);
136+
listener.onResponse(null);
137+
} catch (Exception ex) {
138+
listener.onFailure(ex);
139+
}
140+
}
141+
142+
private void waitForStartToCompleteWithListener(ActionListener<Void> listener) {
143+
try {
144+
waitForStartToComplete();
145+
listener.onResponse(null);
146+
} catch (Exception e) {
147+
listener.onFailure(e);
113148
}
114149
}
115150

151+
/**
152+
* Start various internal services. This is required before sending requests.
153+
*
154+
* NOTE: This method blocks until the startup is complete.
155+
*/
156+
@Override
157+
public void startSynchronously() {
158+
if (startInitiated.compareAndSet(false, true)) {
159+
ActionListener<Void> listener = ActionListener.wrap(
160+
unused -> {},
161+
exception -> logger.error("Http sender failed to start", exception)
162+
);
163+
startInternal(listener);
164+
}
165+
// Handle the case where start*() was already called and this would return immediately because the started flag is already true
166+
waitForStartToComplete();
167+
}
168+
116169
private void waitForStartToComplete() {
117170
try {
118-
if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) {
171+
if (startCompletedLatch.await(startCompletedWaitTime.getMillis(), TimeUnit.MILLISECONDS) == false) {
119172
throw new IllegalStateException("Http sender startup did not complete in time");
120173
}
121174
} catch (InterruptedException e) {
@@ -145,7 +198,7 @@ public void send(
145198
@Nullable TimeValue timeout,
146199
ActionListener<InferenceServiceResults> listener
147200
) {
148-
assert started.get() : "call start() before sending a request";
201+
assert startInitiated.get() : "call start() before sending a request";
149202
waitForStartToComplete();
150203
service.execute(requestCreator, inferenceInputs, timeout, listener);
151204
}
@@ -167,7 +220,7 @@ public void sendWithoutQueuing(
167220
@Nullable TimeValue timeout,
168221
ActionListener<InferenceServiceResults> listener
169222
) {
170-
assert started.get() : "call start() before sending a request";
223+
assert startInitiated.get() : "call start() before sending a request";
171224
waitForStartToComplete();
172225

173226
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import java.io.Closeable;
1919

2020
public interface Sender extends Closeable {
21-
void start();
21+
void startSynchronously();
22+
23+
void startAsynchronously(ActionListener<Void> listener);
2224

2325
void send(
2426
RequestManager requestCreator,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.support.SubscribableListener;
1213
import org.elasticsearch.cluster.service.ClusterService;
1314
import org.elasticsearch.common.ValidationException;
1415
import org.elasticsearch.core.IOUtils;
@@ -73,10 +74,11 @@ public void infer(
7374
@Nullable TimeValue timeout,
7475
ActionListener<InferenceServiceResults> listener
7576
) {
76-
timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
77-
init();
78-
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
79-
doInfer(model, inferenceInput, taskSettings, timeout, listener);
77+
SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((inferListener) -> {
78+
var resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
79+
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
80+
doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener);
81+
}).addListener(listener);
8082
}
8183

8284
private static InferenceInputs createInput(
@@ -121,8 +123,9 @@ public void unifiedCompletionInfer(
121123
TimeValue timeout,
122124
ActionListener<InferenceServiceResults> listener
123125
) {
124-
init();
125-
doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener);
126+
SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((completionInferListener) -> {
127+
doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, completionInferListener);
128+
}).addListener(listener);
126129
}
127130

128131
@Override
@@ -135,16 +138,16 @@ public void chunkedInfer(
135138
TimeValue timeout,
136139
ActionListener<List<ChunkedInference>> listener
137140
) {
138-
init();
139-
140-
ValidationException validationException = new ValidationException();
141-
validateInputType(inputType, model, validationException);
142-
if (validationException.validationErrors().isEmpty() == false) {
143-
throw validationException;
144-
}
141+
SubscribableListener.newForked(this::init).<List<ChunkedInference>>andThen((chunkedInferListener) -> {
142+
ValidationException validationException = new ValidationException();
143+
validateInputType(inputType, model, validationException);
144+
if (validationException.validationErrors().isEmpty() == false) {
145+
throw validationException;
146+
}
145147

146-
// a non-null query is not supported and is dropped by all providers
147-
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
148+
// a non-null query is not supported and is dropped by all providers
149+
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
150+
}).addListener(listener);
148151
}
149152

150153
protected abstract void doInfer(
@@ -176,8 +179,9 @@ protected abstract void doChunkedInfer(
176179
);
177180

178181
public void start(Model model, ActionListener<Boolean> listener) {
179-
init();
180-
doStart(model, listener);
182+
SubscribableListener.newForked(this::init)
183+
.<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))
184+
.addListener(listener);
181185
}
182186

183187
@Override
@@ -189,8 +193,8 @@ protected void doStart(Model model, ActionListener<Boolean> listener) {
189193
listener.onResponse(true);
190194
}
191195

192-
private void init() {
193-
sender.start();
196+
private void init(ActionListener<Void> listener) {
197+
sender.startAsynchronously(listener);
194198
}
195199

196200
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@
7575
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure;
7676
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.providerAllowsTaskType;
7777

78+
/**
79+
* TODO we should remove AmazonBedrockService's dependency on SenderService. Bedrock leverages its own SDK with handles sending requests
80+
* and already implements rate limiting.
81+
*
82+
* https://github.com/elastic/ml-team/issues/1706
83+
*/
7884
public class AmazonBedrockService extends SenderService {
7985
public static final String NAME = "amazonbedrock";
8086
private static final String SERVICE_NAME = "Amazon Bedrock";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public Factory(
7575

7676
public Sender createSender() {
7777
// ensure this is started
78-
bedrockRequestSender.start();
78+
bedrockRequestSender.startSynchronously();
7979
return bedrockRequestSender;
8080
}
8181
}
@@ -98,7 +98,13 @@ protected AmazonBedrockRequestSender(
9898
}
9999

100100
@Override
101-
public void start() {
101+
public void startAsynchronously(ActionListener<Void> listener) {
102+
103+
throw new UnsupportedOperationException("not implemented");
104+
}
105+
106+
@Override
107+
public void startSynchronously() {
102108
if (started.compareAndSet(false, true)) {
103109
// The manager must be started before the executor service. That way we guarantee that the http client
104110
// is ready prior to the service attempting to use the http client to send a request

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.ElasticsearchException;
1313
import org.elasticsearch.ExceptionsHelper;
1414
import org.elasticsearch.action.ActionListener;
15+
import org.elasticsearch.action.support.SubscribableListener;
1516
import org.elasticsearch.common.Strings;
1617
import org.elasticsearch.core.Nullable;
1718
import org.elasticsearch.core.TimeValue;
@@ -82,37 +83,33 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
8283
return;
8384
}
8485

85-
// ensure that the sender is initialized
86-
sender.start();
87-
88-
ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
89-
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
90-
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
91-
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
92-
} else {
93-
var errorMessage = Strings.format(
94-
"%s Received an invalid response type from the Elastic Inference Service: %s",
95-
FAILED_TO_RETRIEVE_MESSAGE,
96-
results.getClass().getSimpleName()
97-
);
98-
99-
logger.warn(errorMessage);
100-
listener.onFailure(new ElasticsearchException(errorMessage));
101-
}
102-
requestCompleteLatch.countDown();
103-
}, e -> {
86+
var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> {
10487
// unwrap because it's likely a retry exception
10588
var exception = ExceptionsHelper.unwrapCause(e);
10689

10790
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
108-
listener.onFailure(e);
109-
requestCompleteLatch.countDown();
91+
authModelListener.onFailure(e);
11092
});
11193

112-
var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
113-
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
94+
SubscribableListener.newForked(sender::startAsynchronously).<InferenceServiceResults>andThen((authListener) -> {
95+
var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
96+
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
97+
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener);
98+
}).andThenApply(authResult -> {
99+
if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
100+
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
101+
return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity);
102+
}
103+
104+
var errorMessage = Strings.format(
105+
"%s Received an invalid response type from the Elastic Inference Service: %s",
106+
FAILED_TO_RETRIEVE_MESSAGE,
107+
authResult.getClass().getSimpleName()
108+
);
114109

115-
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
110+
logger.warn(errorMessage);
111+
throw new ElasticsearchException(errorMessage);
112+
}).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown));
116113
} catch (Exception e) {
117114
logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e));
118115
requestCompleteLatch.countDown();

0 commit comments

Comments
 (0)