diff --git a/docs/changelog/133424.yaml b/docs/changelog/133424.yaml new file mode 100644 index 0000000000000..6b89c4ec44173 --- /dev/null +++ b/docs/changelog/133424.yaml @@ -0,0 +1,5 @@ +pr: 133424 +summary: Ensuring only a single request executor object is created +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java index c8e544c26f293..6c54a8b7a4daa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java @@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; @@ -38,31 +37,46 @@ public class AmazonBedrockRequestSender implements Sender { public static class Factory { private final ServiceComponents serviceComponents; - private final ClusterService clusterService; + private final AmazonBedrockRequestExecutorService executorService; + private final CountDownLatch startCompleted = new CountDownLatch(1); + private final AmazonBedrockRequestSender bedrockRequestSender; public Factory(ServiceComponents serviceComponents, ClusterService clusterService) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.clusterService = Objects.requireNonNull(clusterService); - } - - public Sender createSender() { - var clientCache = new AmazonBedrockInferenceClientCache( - (model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()), - Clock.systemUTC() + this( + serviceComponents, + clusterService, + new AmazonBedrockExecuteOnlyRequestSender( + new AmazonBedrockInferenceClientCache( + (model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()), + Clock.systemUTC() + ), + serviceComponents.throttlerManager() + ) ); - return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager())); } - Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) { - var sender = new AmazonBedrockRequestSender( + public Factory( + ServiceComponents serviceComponents, + ClusterService clusterService, + AmazonBedrockExecuteOnlyRequestSender requestSender + ) { + this.serviceComponents = Objects.requireNonNull(serviceComponents); + Objects.requireNonNull(clusterService); + + executorService = new AmazonBedrockRequestExecutorService( serviceComponents.threadPool(), - clusterService, - serviceComponents.settings(), - Objects.requireNonNull(requestSender) + startCompleted, + new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService), + requestSender ); + + bedrockRequestSender = new AmazonBedrockRequestSender(serviceComponents.threadPool(), executorService, startCompleted); + } + + public Sender createSender() { // ensure this is started - sender.start(); - return sender; + bedrockRequestSender.start(); + return bedrockRequestSender; } } @@ -71,21 +85,16 @@ Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) { private final ThreadPool threadPool; private final AmazonBedrockRequestExecutorService executorService; private final AtomicBoolean started = new AtomicBoolean(false); - private final CountDownLatch startCompleted = new CountDownLatch(1); + private final CountDownLatch startCompleted; protected AmazonBedrockRequestSender( ThreadPool threadPool, - ClusterService clusterService, - Settings settings, - AmazonBedrockExecuteOnlyRequestSender requestSender + AmazonBedrockRequestExecutorService executorService, + CountDownLatch startCompleted ) { this.threadPool = Objects.requireNonNull(threadPool); - executorService = new AmazonBedrockRequestExecutorService( - threadPool, - startCompleted, - new RequestExecutorServiceSettings(settings, clusterService), - requestSender - ); + this.executorService = Objects.requireNonNull(executorService); + this.startCompleted = Objects.requireNonNull(startCompleted); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index 689c9e2ec8fc1..f870f997153a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; @@ -42,32 +41,39 @@ public class HttpRequestSender implements Sender { * A helper class for constructing a {@link HttpRequestSender}. */ public static class Factory { - private final ServiceComponents serviceComponents; - private final HttpClientManager httpClientManager; - private final ClusterService clusterService; - private final RequestSender requestSender; + private final HttpRequestSender httpRequestSender; public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.httpClientManager = Objects.requireNonNull(httpClientManager); - this.clusterService = Objects.requireNonNull(clusterService); + Objects.requireNonNull(serviceComponents); + Objects.requireNonNull(clusterService); + Objects.requireNonNull(httpClientManager); - requestSender = new RetryingHttpSender( - this.httpClientManager.getHttpClient(), + var requestSender = new RetryingHttpSender( + httpClientManager.getHttpClient(), serviceComponents.throttlerManager(), new RetrySettings(serviceComponents.settings(), clusterService), serviceComponents.threadPool() ); - } - public Sender createSender() { - return new HttpRequestSender( + var startCompleted = new CountDownLatch(1); + var service = new RequestExecutorService( serviceComponents.threadPool(), - httpClientManager, - clusterService, - serviceComponents.settings(), + startCompleted, + new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService), requestSender ); + + httpRequestSender = new HttpRequestSender( + serviceComponents.threadPool(), + httpClientManager, + requestSender, + service, + startCompleted + ); + } + + public Sender createSender() { + return httpRequestSender; } } @@ -75,27 +81,23 @@ public Sender createSender() { private final ThreadPool threadPool; private final HttpClientManager manager; - private final RequestExecutor service; private final AtomicBoolean started = new AtomicBoolean(false); - private final CountDownLatch startCompleted = new CountDownLatch(1); private final RequestSender requestSender; + private final RequestExecutor service; + private final CountDownLatch startCompleted; private HttpRequestSender( ThreadPool threadPool, HttpClientManager httpClientManager, - ClusterService clusterService, - Settings settings, - RequestSender requestSender + RequestSender requestSender, + RequestExecutor service, + CountDownLatch startCompleted ) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); this.requestSender = Objects.requireNonNull(requestSender); - service = new RequestExecutorService( - threadPool, - startCompleted, - new RequestExecutorServiceSettings(settings, clusterService), - requestSender - ); + this.service = Objects.requireNonNull(service); + this.startCompleted = Objects.requireNonNull(startCompleted); } /** diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index a8f37aedcece3..b5f1e8f490e74 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -36,7 +36,9 @@ import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; import static org.mockito.Mockito.mock; public class AmazonBedrockRequestSenderTests extends ESTestCase { @@ -59,11 +61,37 @@ public void shutdown() throws IOException, InterruptedException { terminate(threadPool); } + public void testCreateSender_UsesTheSameInstanceForRequestExecutor() throws Exception { + var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); + requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT)); + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); + + var sender1 = createSender(senderFactory); + var sender2 = createSender(senderFactory); + + assertThat(sender1, instanceOf(AmazonBedrockRequestSender.class)); + assertThat(sender2, instanceOf(AmazonBedrockRequestSender.class)); + + assertThat(sender1, sameInstance(sender2)); + } + + public void testCreateSender_CanCallStartMultipleTimes() throws Exception { + var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); + requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT)); + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); + + try (var sender = createSender(senderFactory)) { + sender.start(); + sender.start(); + sender.start(); + } + } + public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception { - var senderFactory = createSenderFactory(threadPool, Settings.EMPTY); var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT)); - try (var sender = createSender(senderFactory, requestSender)) { + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); + try (var sender = createSender(senderFactory)) { sender.start(); var model = AmazonBedrockEmbeddingsModelTests.createModel( @@ -91,10 +119,10 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws } public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception { - var senderFactory = createSenderFactory(threadPool, Settings.EMPTY); var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text")); - try (var sender = createSender(senderFactory, requestSender)) { + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); + try (var sender = createSender(senderFactory)) { sender.start(); var model = AmazonBedrockChatCompletionModelTests.createModel( @@ -115,14 +143,19 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws } } - public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) { + public static AmazonBedrockRequestSender.Factory createSenderFactory( + ThreadPool threadPool, + Settings settings, + AmazonBedrockMockExecuteRequestSender requestSender + ) { return new AmazonBedrockRequestSender.Factory( ServiceComponentsTests.createWithSettings(threadPool, settings), - mockClusterServiceEmpty() + mockClusterServiceEmpty(), + requestSender ); } - public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) { - return factory.createSender(requestSender); + public static Sender createSender(AmazonBedrockRequestSender.Factory factory) { + return factory.createSender(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 5a4953553e92e..c1b0247196f3f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -56,6 +56,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; @@ -88,6 +89,27 @@ public void shutdown() throws IOException, InterruptedException { webServer.close(); } + public void testCreateSender_ReturnsSameRequestExecutorInstance() { + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + var sender1 = createSender(senderFactory); + var sender2 = createSender(senderFactory); + + assertThat(sender1, instanceOf(HttpRequestSender.class)); + assertThat(sender2, instanceOf(HttpRequestSender.class)); + assertThat(sender1, sameInstance(sender2)); + } + + public void testCreateSender_CanCallStartMultipleTimes() throws Exception { + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + sender.start(); + sender.start(); + sender.start(); + } + } + public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());