Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,39 @@ private HttpRequestSender(
}

/**
* Start various internal services. This is required before sending requests.
* Star various internal services asynchronously. This is required before sending requests.
*/
public void start() {
@Override
public void startAsynchronously(ActionListener<Void> listener) {
if (started.compareAndSet(false, true)) {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> startInternal(listener));
} else {
listener.onResponse(null);
}
}

private void startInternal(ActionListener<Void> listener) {
try {
// The manager must be started before the executor service. That way we guarantee that the http client
// is ready prior to the service attempting to use the http client to send a request
manager.start();
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(service::start);
waitForStartToComplete();
listener.onResponse(null);
} catch (Exception ex) {
listener.onFailure(ex);
}
}

/**
* Start various internal services. This is required before sending requests.
*
* NOTE: This method blocks until the startup is complete.
*/
@Override
public void startSynchronously() {
if (started.compareAndSet(false, true)) {
startInternal(ActionListener.noop());
Copy link
Contributor

@DonalEvans DonalEvans Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this cause any exception thrown in startInternal() to be ignored when doing a synchronous start? Also, do we need to make sure that we always call waitForStartToComplete() before returning from this method? If someone calls startAsynchronously() then another thread immediately calls startSynchronously(), the second call will return immediately (because we already set started to true) but the sender won't actually have started yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll make those changes.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import java.io.Closeable;

public interface Sender extends Closeable {
void start();
void startSynchronously();

void startAsynchronously(ActionListener<Void> listener);

void send(
RequestManager requestCreator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.IOUtils;
Expand Down Expand Up @@ -73,10 +74,11 @@ public void infer(
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
init();
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
doInfer(model, inferenceInput, taskSettings, timeout, listener);
SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((inferListener) -> {
var resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener);
}).addListener(listener);
}

private static InferenceInputs createInput(
Expand Down Expand Up @@ -121,8 +123,9 @@ public void unifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
init();
doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener);
SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen((completionInferListener) -> {
doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, completionInferListener);
}).addListener(listener);
}

@Override
Expand All @@ -135,16 +138,16 @@ public void chunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
init();

ValidationException validationException = new ValidationException();
validateInputType(inputType, model, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
SubscribableListener.newForked(this::init).<List<ChunkedInference>>andThen((chunkedInferListener) -> {
ValidationException validationException = new ValidationException();
validateInputType(inputType, model, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
}).addListener(listener);
}

protected abstract void doInfer(
Expand Down Expand Up @@ -176,8 +179,9 @@ protected abstract void doChunkedInfer(
);

public void start(Model model, ActionListener<Boolean> listener) {
init();
doStart(model, listener);
SubscribableListener.newForked(this::init)
.<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of calling doStart() here? It seems to be a no-op that just immediately returns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is that it can be overridden by child classes. In reality I don't think any actually override it yet. The Elasticsearch integration does use it but that doesn't extend from SenderService.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, thanks for the explanation

.addListener(listener);
}

@Override
Expand All @@ -189,8 +193,8 @@ protected void doStart(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

private void init() {
sender.start();
private void init(ActionListener<Void> listener) {
sender.startAsynchronously(listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Factory(

public Sender createSender() {
// ensure this is started
bedrockRequestSender.start();
bedrockRequestSender.startSynchronously();
return bedrockRequestSender;
}
}
Expand All @@ -97,8 +97,17 @@ protected AmazonBedrockRequestSender(
this.startCompleted = Objects.requireNonNull(startCompleted);
}

/**
* TODO implement this functionality to ensure that we don't block node bootups
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converting bedrock is going to take a little more work. Probably best to do this in a separate PR because this one is already 50 files 😬

* See: https://github.com/elastic/ml-team/issues/1701
*/
@Override
public void start() {
public void startAsynchronously(ActionListener<Void> listener) {
throw new UnsupportedOperationException("not implemented");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth wrapping this throw in a check on the value of started? If the sender has already been started, then calling startAsynchronously() should have no effect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think in that situation we should still throw. It would be a bug if we're ever calling that method for AmazonBedrockRequestSender.

}

@Override
public void startSynchronously() {
if (started.compareAndSet(false, true)) {
// The manager must be started before the executor service. That way we guarantee that the http client
// is ready prior to the service attempting to use the http client to send a request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -82,37 +83,33 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
return;
}

// ensure that the sender is initialized
sender.start();

ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
} else {
var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
results.getClass().getSimpleName()
);

logger.warn(errorMessage);
listener.onFailure(new ElasticsearchException(errorMessage));
}
requestCompleteLatch.countDown();
}, e -> {
var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> {
// unwrap because it's likely a retry exception
var exception = ExceptionsHelper.unwrapCause(e);

logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
listener.onFailure(e);
requestCompleteLatch.countDown();
authModelListener.onFailure(e);
});

var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
SubscribableListener.newForked(sender::startAsynchronously).<InferenceServiceResults>andThen((authListener) -> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we're doing an async start and then once that completes we do the rest of the functionality as normal.

var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata);
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener);
}).andThenApply(authResult -> {
if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity);
}

var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
authResult.getClass().getSimpleName()
);

sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
logger.warn(errorMessage);
throw new ElasticsearchException(errorMessage);
}).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown));
} catch (Exception e) {
logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e));
requestCompleteLatch.countDown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,17 @@ public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = createSender(senderFactory)) {
sender.start();
sender.start();
sender.start();
sender.startSynchronously();
sender.startSynchronously();
sender.startSynchronously();
Comment on lines +117 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to add some tests for the startAsynchronously() method, since it's a distinct implementation from startSynchronously(). Also, a test that calling startAsynchronously() followed immediately by startSynchronously() behaves the way we expect would be good.

}
}

public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = createSender(senderFactory)) {
sender.start();
sender.startSynchronously();

String responseJson = """
{
Expand Down Expand Up @@ -167,7 +167,7 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce
var senderFactory = createSenderFactory(clientManager, threadRef);

try (var sender = createSender(senderFactory)) {
sender.start();
sender.startSynchronously();

String responseJson = """
{
Expand Down Expand Up @@ -240,7 +240,7 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception {

try (var sender = senderFactory.createSender()) {
assertThat(sender, instanceOf(HttpRequestSender.class));
sender.start();
sender.startSynchronously();

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener);
Expand All @@ -263,7 +263,7 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws
);

try (var sender = senderFactory.createSender()) {
sender.start();
sender.startSynchronously();

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener);
Expand All @@ -286,7 +286,7 @@ public void testSendWithoutQueuingWithTimeout_Throws_WhenATimeoutOccurs() throws
);

try (var sender = senderFactory.createSender()) {
sender.start();
sender.startSynchronously();

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.sendWithoutQueuing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var sender = createSender(senderFactory)) {
sender.start();
sender.startSynchronously();

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -65,7 +67,7 @@ public void shutdown() throws IOException {
}

public void testStart_InitializesTheSender() throws IOException {
var sender = mock(Sender.class);
var sender = createMockSender();

var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);
Expand All @@ -75,7 +77,7 @@ public void testStart_InitializesTheSender() throws IOException {
service.start(mock(Model.class), listener);

listener.actionGet(TIMEOUT);
verify(sender, times(1)).start();
verify(sender, times(1)).startAsynchronously(any());
verify(factory, times(1)).createSender();
}

Expand All @@ -85,7 +87,7 @@ public void testStart_InitializesTheSender() throws IOException {
}

public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOException {
var sender = mock(Sender.class);
var sender = createMockSender();

var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);
Expand All @@ -95,11 +97,13 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
service.start(mock(Model.class), listener);
listener.actionGet(TIMEOUT);

service.start(mock(Model.class), listener);
listener.actionGet(TIMEOUT);

PlainActionFuture<Boolean> listener2 = new PlainActionFuture<>();
service.start(mock(Model.class), listener2);
listener2.actionGet(TIMEOUT);

verify(factory, times(1)).createSender();
verify(sender, times(2)).start();
verify(sender, times(2)).startAsynchronously(any());
}

verify(sender, times(1)).close();
Expand All @@ -108,7 +112,8 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
}

public void test_nullTimeoutUsesClusterSetting() throws IOException {
var sender = mock(Sender.class);
var sender = createMockSender();

var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

Expand Down Expand Up @@ -147,7 +152,7 @@ protected void doInfer(
}

public void test_providedTimeoutPropagateProperly() throws IOException {
var sender = mock(Sender.class);
var sender = createMockSender();
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

Expand Down Expand Up @@ -185,6 +190,18 @@ protected void doInfer(
}
}

public static Sender createMockSender() {
var sender = mock(Sender.class);
doAnswer(invocationOnMock -> {
ActionListener<Void> listener = invocationOnMock.getArgument(0);
listener.onResponse(null);
return Void.TYPE;
}).when(sender).startAsynchronously(any());

return sender;
}


private static class TestSenderService extends SenderService {
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
Expand Down
Loading