Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/133424.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 133424
summary: Ensuring only a single request executor object is created
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -38,31 +37,46 @@ public class AmazonBedrockRequestSender implements Sender {

public static class Factory {
private final ServiceComponents serviceComponents;
private final ClusterService clusterService;
private final AmazonBedrockRequestExecutorService executorService;
private final CountDownLatch startCompleted = new CountDownLatch(1);
private final AmazonBedrockRequestSender bedrockRequestSender;

public Factory(ServiceComponents serviceComponents, ClusterService clusterService) {
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.clusterService = Objects.requireNonNull(clusterService);
}

public Sender createSender() {
var clientCache = new AmazonBedrockInferenceClientCache(
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
Clock.systemUTC()
this(
serviceComponents,
clusterService,
new AmazonBedrockExecuteOnlyRequestSender(
new AmazonBedrockInferenceClientCache(
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
Clock.systemUTC()
),
serviceComponents.throttlerManager()
)
);
return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
}

Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
var sender = new AmazonBedrockRequestSender(
public Factory(
ServiceComponents serviceComponents,
ClusterService clusterService,
AmazonBedrockExecuteOnlyRequestSender requestSender
) {
this.serviceComponents = Objects.requireNonNull(serviceComponents);
Objects.requireNonNull(clusterService);

executorService = new AmazonBedrockRequestExecutorService(
serviceComponents.threadPool(),
clusterService,
serviceComponents.settings(),
Objects.requireNonNull(requestSender)
startCompleted,
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
requestSender
);

bedrockRequestSender = new AmazonBedrockRequestSender(serviceComponents.threadPool(), executorService, startCompleted);
}

public Sender createSender() {
// ensure this is started
sender.start();
return sender;
bedrockRequestSender.start();
return bedrockRequestSender;
}
}

Expand All @@ -71,21 +85,16 @@ Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
private final ThreadPool threadPool;
private final AmazonBedrockRequestExecutorService executorService;
private final AtomicBoolean started = new AtomicBoolean(false);
private final CountDownLatch startCompleted = new CountDownLatch(1);
private final CountDownLatch startCompleted;

protected AmazonBedrockRequestSender(
ThreadPool threadPool,
ClusterService clusterService,
Settings settings,
AmazonBedrockExecuteOnlyRequestSender requestSender
AmazonBedrockRequestExecutorService executorService,
CountDownLatch startCompleted
) {
this.threadPool = Objects.requireNonNull(threadPool);
executorService = new AmazonBedrockRequestExecutorService(
threadPool,
startCompleted,
new RequestExecutorServiceSettings(settings, clusterService),
requestSender
);
this.executorService = Objects.requireNonNull(executorService);
this.startCompleted = Objects.requireNonNull(startCompleted);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -42,60 +41,63 @@ public class HttpRequestSender implements Sender {
* A helper class for constructing a {@link HttpRequestSender}.
*/
public static class Factory {
private final ServiceComponents serviceComponents;
private final HttpClientManager httpClientManager;
private final ClusterService clusterService;
private final RequestSender requestSender;
private final HttpRequestSender httpRequestSender;

public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.httpClientManager = Objects.requireNonNull(httpClientManager);
this.clusterService = Objects.requireNonNull(clusterService);
Objects.requireNonNull(serviceComponents);
Objects.requireNonNull(clusterService);
Objects.requireNonNull(httpClientManager);

requestSender = new RetryingHttpSender(
this.httpClientManager.getHttpClient(),
var requestSender = new RetryingHttpSender(
httpClientManager.getHttpClient(),
serviceComponents.throttlerManager(),
new RetrySettings(serviceComponents.settings(), clusterService),
serviceComponents.threadPool()
);
}

public Sender createSender() {
return new HttpRequestSender(
var startCompleted = new CountDownLatch(1);
var service = new RequestExecutorService(
serviceComponents.threadPool(),
httpClientManager,
clusterService,
serviceComponents.settings(),
startCompleted,
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
requestSender
);

httpRequestSender = new HttpRequestSender(
serviceComponents.threadPool(),
httpClientManager,
requestSender,
service,
startCompleted
);
}

public Sender createSender() {
return httpRequestSender;
}
}

private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);

private final ThreadPool threadPool;
private final HttpClientManager manager;
private final RequestExecutor service;
private final AtomicBoolean started = new AtomicBoolean(false);
private final CountDownLatch startCompleted = new CountDownLatch(1);
private final RequestSender requestSender;
private final RequestExecutor service;
private final CountDownLatch startCompleted;

private HttpRequestSender(
ThreadPool threadPool,
HttpClientManager httpClientManager,
ClusterService clusterService,
Settings settings,
RequestSender requestSender
RequestSender requestSender,
RequestExecutor service,
CountDownLatch startCompleted
) {
this.threadPool = Objects.requireNonNull(threadPool);
this.manager = Objects.requireNonNull(httpClientManager);
this.requestSender = Objects.requireNonNull(requestSender);
service = new RequestExecutorService(
threadPool,
startCompleted,
new RequestExecutorServiceSettings(settings, clusterService),
requestSender
);
this.service = Objects.requireNonNull(service);
this.startCompleted = Objects.requireNonNull(startCompleted);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.Mockito.mock;

public class AmazonBedrockRequestSenderTests extends ESTestCase {
Expand All @@ -59,11 +61,37 @@ public void shutdown() throws IOException, InterruptedException {
terminate(threadPool);
}

public void testCreateSender_UsesTheSameInstanceForRequestExecutor() throws Exception {
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);

var sender1 = createSender(senderFactory);
var sender2 = createSender(senderFactory);

assertThat(sender1, instanceOf(AmazonBedrockRequestSender.class));
assertThat(sender2, instanceOf(AmazonBedrockRequestSender.class));

assertThat(sender1, sameInstance(sender2));
}

public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);

try (var sender = createSender(senderFactory)) {
sender.start();
sender.start();
sender.start();
}
}

public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception {
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
try (var sender = createSender(senderFactory, requestSender)) {
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
try (var sender = createSender(senderFactory)) {
sender.start();

var model = AmazonBedrockEmbeddingsModelTests.createModel(
Expand Down Expand Up @@ -91,10 +119,10 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws
}

public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception {
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text"));
try (var sender = createSender(senderFactory, requestSender)) {
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
try (var sender = createSender(senderFactory)) {
sender.start();

var model = AmazonBedrockChatCompletionModelTests.createModel(
Expand All @@ -115,14 +143,19 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws
}
}

public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) {
public static AmazonBedrockRequestSender.Factory createSenderFactory(
ThreadPool threadPool,
Settings settings,
AmazonBedrockMockExecuteRequestSender requestSender
) {
return new AmazonBedrockRequestSender.Factory(
ServiceComponentsTests.createWithSettings(threadPool, settings),
mockClusterServiceEmpty()
mockClusterServiceEmpty(),
requestSender
);
}

public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) {
return factory.createSender(requestSender);
public static Sender createSender(AmazonBedrockRequestSender.Factory factory) {
return factory.createSender();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -88,6 +89,27 @@ public void shutdown() throws IOException, InterruptedException {
webServer.close();
}

public void testCreateSender_ReturnsSameRequestExecutorInstance() {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

var sender1 = createSender(senderFactory);
var sender2 = createSender(senderFactory);

assertThat(sender1, instanceOf(HttpRequestSender.class));
assertThat(sender2, instanceOf(HttpRequestSender.class));
assertThat(sender1, sameInstance(sender2));
}

public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = createSender(senderFactory)) {
sender.start();
sender.start();
sender.start();
}
}

public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

Expand Down