Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
Expand Down Expand Up @@ -42,6 +43,7 @@ public class HttpRequestSender implements Sender {
*/
public static class Factory {
private final HttpRequestSender httpRequestSender;
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);

public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
Objects.requireNonNull(serviceComponents);
Expand All @@ -68,7 +70,8 @@ public Factory(ServiceComponents serviceComponents, HttpClientManager httpClient
httpClientManager,
requestSender,
service,
startCompleted
startCompleted,
START_COMPLETED_WAIT_TIME
);
}

Expand All @@ -77,45 +80,95 @@ public Sender createSender() {
}
}

private static final Logger logger = LogManager.getLogger(HttpRequestSender.class);

private final ThreadPool threadPool;
private final HttpClientManager manager;
// Set once the first start*() call wins the race to initiate startup.
private final AtomicBoolean startInitiated = new AtomicBoolean(false);
// Set once startup has fully finished; lets later start calls return immediately.
private final AtomicBoolean startCompleted = new AtomicBoolean(false);
private final RequestSender requestSender;
private final RequestExecutor service;
// Released when the underlying request executor has started; awaited with a bounded timeout.
private final CountDownLatch startCompletedLatch;
private final TimeValue startCompletedWaitTime;

// Visible for testing
protected HttpRequestSender(
    ThreadPool threadPool,
    HttpClientManager httpClientManager,
    RequestSender requestSender,
    RequestExecutor service,
    CountDownLatch startCompletedLatch,
    TimeValue startCompletedWaitTime
) {
    this.threadPool = Objects.requireNonNull(threadPool);
    this.manager = Objects.requireNonNull(httpClientManager);
    this.requestSender = Objects.requireNonNull(requestSender);
    this.service = Objects.requireNonNull(service);
    this.startCompletedLatch = Objects.requireNonNull(startCompletedLatch);
    this.startCompletedWaitTime = Objects.requireNonNull(startCompletedWaitTime);
}

/**
* Start various internal services. This is required before sending requests.
* Start various internal services asynchronously. This is required before sending requests.
*/
public void start() {
if (started.compareAndSet(false, true)) {
@Override
public void startAsynchronously(ActionListener<Void> listener) {
if (startInitiated.compareAndSet(false, true)) {
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> startInternal(preservedListener));
} else if (startCompleted.get() == false) {
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
// wait on another thread so we don't potential block a transport thread
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> waitForStartToCompleteWithListener(preservedListener));
} else {
listener.onResponse(null);
}
}

// Performs the actual startup work; runs on a utility thread and reports the outcome to the listener.
private void startInternal(ActionListener<Void> listener) {
try {
// The manager must be started before the executor service. That way we guarantee that the http client
// is ready prior to the service attempting to use the http client to send a request
manager.start();
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(service::start);
// Block (bounded by startCompletedWaitTime) until the startup latch is released —
// presumably counted down by the request executor once it is running; confirm against RequestExecutor.
waitForStartToComplete();
// Record completion so subsequent start calls can return without waiting.
startCompleted.set(true);
listener.onResponse(null);
} catch (Exception ex) {
listener.onFailure(ex);
}
}

// Blocks the current (utility) thread until startup finishes, then relays the outcome to the listener.
// Any failure — including the bounded wait timing out — is forwarded via onFailure.
private void waitForStartToCompleteWithListener(ActionListener<Void> listener) {
try {
waitForStartToComplete();
listener.onResponse(null);
} catch (Exception e) {
listener.onFailure(e);
}
}

/**
* Start various internal services. This is required before sending requests.
*
* NOTE: This method blocks until the startup is complete.
*/
@Override
public void startSynchronously() {
if (startInitiated.compareAndSet(false, true)) {
ActionListener<Void> listener = ActionListener.wrap(
unused -> {},
exception -> logger.error("Http sender failed to start", exception)
);
startInternal(listener);
}
// Handle the case where start*() was already called and this would return immediately because the started flag is already true
waitForStartToComplete();
Comment on lines +165 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we need to do something similar for async calls, since if two async calls come in one after the other, the second one will complete immediately even if the first one hasn't finished starting the sender yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I tried to come up with a solution that would avoid having to do spin up a thread to then call the waitForStartToComplete since most of the time it will simply return.

}

private void waitForStartToComplete() {
try {
if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) {
if (startCompletedLatch.await(startCompletedWaitTime.getMillis(), TimeUnit.MILLISECONDS) == false) {
throw new IllegalStateException("Http sender startup did not complete in time");
}
} catch (InterruptedException e) {
Expand Down Expand Up @@ -145,7 +198,7 @@ public void send(
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
assert started.get() : "call start() before sending a request";
assert startInitiated.get() : "call start() before sending a request";
waitForStartToComplete();
service.execute(requestCreator, inferenceInputs, timeout, listener);
}
Expand All @@ -167,7 +220,7 @@ public void sendWithoutQueuing(
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
assert started.get() : "call start() before sending a request";
assert startInitiated.get() : "call start() before sending a request";
waitForStartToComplete();

var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import java.io.Closeable;

public interface Sender extends Closeable {
void start();
void startSynchronously();

void startAsynchronously(ActionListener<Void> listener);

void send(
RequestManager requestCreator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.IOUtils;
Expand Down Expand Up @@ -73,10 +74,11 @@ public void infer(
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
init();
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
doInfer(model, inferenceInput, taskSettings, timeout, listener);
SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((inferListener) -> {
var resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener);
}).addListener(listener);
}

private static InferenceInputs createInput(
Expand Down Expand Up @@ -121,8 +123,9 @@ public void unifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
init();
doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener);
SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((completionInferListener) -> {
doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, completionInferListener);
}).addListener(listener);
}

@Override
Expand All @@ -135,16 +138,16 @@ public void chunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
init();

ValidationException validationException = new ValidationException();
validateInputType(inputType, model, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
SubscribableListener.newForked(this::init).<List<ChunkedInference>>andThen((chunkedInferListener) -> {
ValidationException validationException = new ValidationException();
validateInputType(inputType, model, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
}).addListener(listener);
}

protected abstract void doInfer(
Expand Down Expand Up @@ -176,8 +179,9 @@ protected abstract void doChunkedInfer(
);

public void start(Model model, ActionListener<Boolean> listener) {
init();
doStart(model, listener);
SubscribableListener.newForked(this::init)
.<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of calling doStart() here? It seems to be a no-op that just immediately returns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is that it can be overridden by child classes. In reality I don't think any actually override it yet. The Elasticsearch integration does use it but that doesn't extend from SenderService.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, thanks for the explanation

.addListener(listener);
}

@Override
Expand All @@ -189,8 +193,8 @@ protected void doStart(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

private void init() {
sender.start();
private void init(ActionListener<Void> listener) {
sender.startAsynchronously(listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.providerAllowsTaskType;

/**
 * TODO we should remove AmazonBedrockService's dependency on SenderService. Bedrock leverages its own SDK which handles sending requests
* and already implements rate limiting.
*
* https://github.com/elastic/ml-team/issues/1706
*/
public class AmazonBedrockService extends SenderService {
public static final String NAME = "amazonbedrock";
private static final String SERVICE_NAME = "Amazon Bedrock";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Factory(

public Sender createSender() {
    // ensure this is started before handing it out; startSynchronously is a no-op once started
    bedrockRequestSender.startSynchronously();
    return bedrockRequestSender;
}
}
Expand All @@ -98,7 +98,13 @@ protected AmazonBedrockRequestSender(
}

@Override
public void start() {
public void startAsynchronously(ActionListener<Void> listener) {

throw new UnsupportedOperationException("not implemented");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth wrapping this throw in a check on the value of started? If the sender has already been started, then calling startAsynchronously() should have no effect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think in that situation we should still throw. It would be a bug if we're ever calling that method for AmazonBedrockRequestSender.

}

@Override
public void startSynchronously() {
if (started.compareAndSet(false, true)) {
// The manager must be started before the executor service. That way we guarantee that the http client
// is ready prior to the service attempting to use the http client to send a request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -82,37 +83,33 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
return;
}

// ensure that the sender is initialized
sender.start();

ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
} else {
var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
results.getClass().getSimpleName()
);

logger.warn(errorMessage);
listener.onFailure(new ElasticsearchException(errorMessage));
}
requestCompleteLatch.countDown();
}, e -> {
var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> {
// unwrap because it's likely a retry exception
var exception = ExceptionsHelper.unwrapCause(e);

logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
listener.onFailure(e);
requestCompleteLatch.countDown();
authModelListener.onFailure(e);
});

var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
SubscribableListener.newForked(sender::startAsynchronously).<InferenceServiceResults>andThen((authListener) -> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we're doing an async start and then once that completes we do the rest of the functionality as normal.

var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener);
}).andThenApply(authResult -> {
if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity);
}

var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
authResult.getClass().getSimpleName()
);

sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
logger.warn(errorMessage);
throw new ElasticsearchException(errorMessage);
}).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown));
} catch (Exception e) {
logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e));
requestCompleteLatch.countDown();
Expand Down
Loading