Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -77,6 +79,7 @@ public Sender createSender() {
}
}

private static final Logger logger = LogManager.getLogger(HttpRequestSender.class);
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);

private final ThreadPool threadPool;
Expand Down Expand Up @@ -133,8 +136,14 @@ private void startInternal(ActionListener<Void> listener) {
@Override
public void startSynchronously() {
if (started.compareAndSet(false, true)) {
startInternal(ActionListener.noop());
ActionListener<Void> listener = ActionListener.wrap(unused -> {}, exception -> {
logger.error("Http sender failed to start", exception);
ExceptionsHelper.maybeDieOnAnotherThread(exception);
});
startInternal(listener);
}
// Handle the case where start*() was already called and this would return immediately because the started flag is already true
waitForStartToComplete();
Comment on lines +165 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we need to do something similar for async calls, since if two async calls come in one after the other, the second one will complete immediately even if the first one hasn't finished starting the sender yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I tried to come up with a solution that would avoid having to do spin up a thread to then call the waitForStartToComplete since most of the time it will simply return.

}

private void waitForStartToComplete() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ protected AmazonBedrockRequestSender(

@Override
public void startAsynchronously(ActionListener<Void> listener) {

throw new UnsupportedOperationException("not implemented");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth wrapping this throw in a check on the value of started? If the sender has already been started, then calling startAsynchronously() should have no effect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think in that situation we should still throw. It would be a bug if we're ever calling that method for AmazonBedrockRequestSender.

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
Expand All @@ -61,6 +62,7 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -111,6 +113,91 @@ public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
}
}

public void testStart_ThrowsException_WhenAnErrorOccurs() throws IOException {
var mockManager = mock(HttpClientManager.class);
when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class));
doThrow(new Error("failed")).when(mockManager).start();

var senderFactory = new HttpRequestSender.Factory(
ServiceComponentsTests.createWithEmptySettings(threadPool),
mockManager,
mockClusterServiceEmpty()
);

try (var sender = senderFactory.createSender()) {
// Checking for both exception types because there's a race condition between the Error being thrown on a separate thread
// and the startCompleted latch timing out waiting for the start to complete
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HttpRequestSender.startInternal(), only catches and handles Exception, so any Error thrown in that method will always escape and cause the listener to not be invoked, meaning that the maybeDieOnAnotherThread() call never happens, and neither does the waitForStartToComplete() call in startSynchronously(), so we wouldn't ever expect to see the IllegalStateException get thrown from waitForStartToComplete().

If I change to test to use an IllegalArgumentException wrapping an Error, then the listener is invoked and we always get the IllegalStateException thrown from startSynchronously() due to timing out waiting for the sender to start. However, with that change, the test fails due to the error being thrown in another thread. I don't know how to tell a test to expect an exception to be thrown in another thread, but it looks like CloseFollowerIndexIT.wrapUncaughtExceptionHandler() might be trying to solve the same problem.

I wonder if we need to rethrow the Error at all in the case where we catch an Exception with an Error as one of its causes, or just log it and allow the waitForStartToComplete() call to inevitably time out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.

I wonder if we need to rethrow the Error at all in the case where we catch an Exception with an Error as one of its causes, or just log it and allow the waitForStartToComplete() call to inevitably time out?

Yeah I think I'm going to just log it and rely on the waitForStartToComplete(). After we refactor bedrock, I'm pretty sure we can remove the startSynchronously() all together or just use it for tests.

var exception = expectThrowsAnyOf(List.of(Error.class, IllegalStateException.class), sender::startSynchronously);

if (exception instanceof Error) {
assertThat(exception.getMessage(), is("failed"));
} else {
// IllegalStateException can be thrown if the startCompleted latch times out waiting for the start to complete
assertThat(exception.getMessage(), is("Http sender startup did not complete in time"));
}
}
}

public void testStart_ThrowsExceptionWaitingForStartToComplete() throws IOException {
var mockManager = mock(HttpClientManager.class);
when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class));
// This won't get rethrown because it is not an Error
doThrow(new IllegalArgumentException("failed")).when(mockManager).start();

var senderFactory = new HttpRequestSender.Factory(
ServiceComponentsTests.createWithEmptySettings(threadPool),
mockManager,
mockClusterServiceEmpty()
);

try (var sender = senderFactory.createSender()) {
var exception = expectThrows(IllegalStateException.class, sender::startSynchronously);

assertThat(exception.getMessage(), is("Http sender startup did not complete in time"));
}
}

public void testCreateSender_CanCallStartAsyncMultipleTimes() throws Exception {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test and the one below it could be improved a little by verifying that no matter how many times we call startAsynchronously() or startSynchronously(), we only call HttpClientManager.start() once:

        var clientManagerSpy = spy(clientManager);
        var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManagerSpy, mockClusterServiceEmpty());
...
            for (int i = 0; i < asyncCalls; i++) {
                PlainActionFuture<Void> listener = listenerList.get(i);
                assertNull(listener.actionGet(TIMEOUT));
            }

            verify(clientManagerSpy, times(1)).start();

It would also be nice if we could verify that we're calling waitForStartToComplete() the expected number of times.

var asyncCalls = 3;
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = createSender(senderFactory)) {
var listenerList = new ArrayList<PlainActionFuture<Void>>();

for (int i = 0; i < asyncCalls; i++) {
PlainActionFuture<Void> listener = new PlainActionFuture<>();
listenerList.add(listener);
sender.startAsynchronously(listener);
}

for (int i = 0; i < asyncCalls; i++) {
PlainActionFuture<Void> listener = listenerList.get(i);
assertNull(listener.actionGet(TIMEOUT));
}
}
}

public void testCreateSender_CanCallStartAsyncAndSyncMultipleTimes() throws Exception {
var asyncCalls = 3;
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = createSender(senderFactory)) {
var listenerList = new ArrayList<PlainActionFuture<Void>>();

for (int i = 0; i < asyncCalls; i++) {
PlainActionFuture<Void> listener = new PlainActionFuture<>();
listenerList.add(listener);
sender.startAsynchronously(listener);
sender.startSynchronously();
}

for (int i = 0; i < asyncCalls; i++) {
PlainActionFuture<Void> listener = listenerList.get(i);
assertNull(listener.actionGet(TIMEOUT));
}
}
}

public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public void startSynchronously() {

@Override
public void startAsynchronously(ActionListener<Void> listener) {
listener.onResponse(null);
throw new UnsupportedOperationException("not supported");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,43 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
}
}

public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
var eisGatewayUrl = getUrl(webServer);
var logger = mock(Logger.class);
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger);

PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
ActionListener<ElasticInferenceServiceAuthorizationModel> onlyOnceListener = ActionListener.assertOnce(listener);
String responseJson = """
{
"models": [
{
"model_name": "model-a",
"task_types": ["embed/text/sparse", "chat"]
}
]
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var sender = senderFactory.createSender()) {
authHandler.getAuthorization(onlyOnceListener, sender);
authHandler.waitForAuthRequestCompletion(TIMEOUT);

var authResponse = listener.actionGet(TIMEOUT);
assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a")));
assertTrue(authResponse.isAuthorized());

var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
verify(logger, times(1)).debug(loggerArgsCaptor.capture());

var message = loggerArgsCaptor.getValue();
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
}
}

public void testGetAuthorization_InvalidResponse() throws IOException {
var senderMock = createMockSender();
var senderFactory = mock(HttpRequestSender.Factory.class);
Expand Down