Skip to content

Commit baf50fe

Browse files
[ML] Ensuring only a single request executor object is created (#133424) (#133721)
* Ensuring only a single request executor thread is started
* Reverting test changes
* Update docs/changelog/133424.yaml

(cherry picked from commit d5a9343)

# Conflicts:
#   x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java
#   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java
1 parent ebc7f0c commit baf50fe

File tree

5 files changed

+127
-64
lines changed

5 files changed

+127
-64
lines changed

docs/changelog/133424.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 133424
2+
summary: Ensuring only a single request executor object is created
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java

Lines changed: 37 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.ElasticsearchException;
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.cluster.service.ClusterService;
13-
import org.elasticsearch.common.settings.Settings;
1413
import org.elasticsearch.core.TimeValue;
1514
import org.elasticsearch.inference.InferenceServiceResults;
1615
import org.elasticsearch.threadpool.ThreadPool;
@@ -35,31 +34,46 @@ public class AmazonBedrockRequestSender implements Sender {
3534

3635
public static class Factory {
3736
private final ServiceComponents serviceComponents;
38-
private final ClusterService clusterService;
37+
private final AmazonBedrockRequestExecutorService executorService;
38+
private final CountDownLatch startCompleted = new CountDownLatch(1);
39+
private final AmazonBedrockRequestSender bedrockRequestSender;
3940

4041
public Factory(ServiceComponents serviceComponents, ClusterService clusterService) {
41-
this.serviceComponents = Objects.requireNonNull(serviceComponents);
42-
this.clusterService = Objects.requireNonNull(clusterService);
43-
}
44-
45-
public Sender createSender() {
46-
var clientCache = new AmazonBedrockInferenceClientCache(
47-
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
48-
Clock.systemUTC()
42+
this(
43+
serviceComponents,
44+
clusterService,
45+
new AmazonBedrockExecuteOnlyRequestSender(
46+
new AmazonBedrockInferenceClientCache(
47+
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
48+
Clock.systemUTC()
49+
),
50+
serviceComponents.throttlerManager()
51+
)
4952
);
50-
return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
5153
}
5254

53-
Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
54-
var sender = new AmazonBedrockRequestSender(
55+
public Factory(
56+
ServiceComponents serviceComponents,
57+
ClusterService clusterService,
58+
AmazonBedrockExecuteOnlyRequestSender requestSender
59+
) {
60+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
61+
Objects.requireNonNull(clusterService);
62+
63+
executorService = new AmazonBedrockRequestExecutorService(
5564
serviceComponents.threadPool(),
56-
clusterService,
57-
serviceComponents.settings(),
58-
Objects.requireNonNull(requestSender)
65+
startCompleted,
66+
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
67+
requestSender
5968
);
69+
70+
bedrockRequestSender = new AmazonBedrockRequestSender(serviceComponents.threadPool(), executorService, startCompleted);
71+
}
72+
73+
public Sender createSender() {
6074
// ensure this is started
61-
sender.start();
62-
return sender;
75+
bedrockRequestSender.start();
76+
return bedrockRequestSender;
6377
}
6478
}
6579

@@ -68,21 +82,16 @@ Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
6882
private final ThreadPool threadPool;
6983
private final AmazonBedrockRequestExecutorService executorService;
7084
private final AtomicBoolean started = new AtomicBoolean(false);
71-
private final CountDownLatch startCompleted = new CountDownLatch(1);
85+
private final CountDownLatch startCompleted;
7286

7387
protected AmazonBedrockRequestSender(
7488
ThreadPool threadPool,
75-
ClusterService clusterService,
76-
Settings settings,
77-
AmazonBedrockExecuteOnlyRequestSender requestSender
89+
AmazonBedrockRequestExecutorService executorService,
90+
CountDownLatch startCompleted
7891
) {
7992
this.threadPool = Objects.requireNonNull(threadPool);
80-
executorService = new AmazonBedrockRequestExecutorService(
81-
threadPool,
82-
startCompleted,
83-
new RequestExecutorServiceSettings(settings, clusterService),
84-
requestSender
85-
);
93+
this.executorService = Objects.requireNonNull(executorService);
94+
this.startCompleted = Objects.requireNonNull(startCompleted);
8695
}
8796

8897
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

Lines changed: 22 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,12 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.cluster.service.ClusterService;
12-
import org.elasticsearch.common.settings.Settings;
1312
import org.elasticsearch.core.Nullable;
1413
import org.elasticsearch.core.TimeValue;
1514
import org.elasticsearch.inference.InferenceServiceResults;
1615
import org.elasticsearch.threadpool.ThreadPool;
1716
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
1817
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
19-
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
2018
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
2119
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
2220
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -38,58 +36,54 @@ public class HttpRequestSender implements Sender {
3836
* A helper class for constructing a {@link HttpRequestSender}.
3937
*/
4038
public static class Factory {
41-
private final ServiceComponents serviceComponents;
42-
private final HttpClientManager httpClientManager;
43-
private final ClusterService clusterService;
44-
private final RequestSender requestSender;
39+
private final HttpRequestSender httpRequestSender;
4540

4641
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
47-
this.serviceComponents = Objects.requireNonNull(serviceComponents);
48-
this.httpClientManager = Objects.requireNonNull(httpClientManager);
49-
this.clusterService = Objects.requireNonNull(clusterService);
42+
Objects.requireNonNull(serviceComponents);
43+
Objects.requireNonNull(clusterService);
44+
Objects.requireNonNull(httpClientManager);
5045

51-
requestSender = new RetryingHttpSender(
52-
this.httpClientManager.getHttpClient(),
46+
var requestSender = new RetryingHttpSender(
47+
httpClientManager.getHttpClient(),
5348
serviceComponents.throttlerManager(),
5449
new RetrySettings(serviceComponents.settings(), clusterService),
5550
serviceComponents.threadPool()
5651
);
57-
}
5852

59-
public Sender createSender() {
60-
return new HttpRequestSender(
53+
var startCompleted = new CountDownLatch(1);
54+
var service = new RequestExecutorService(
6155
serviceComponents.threadPool(),
62-
httpClientManager,
63-
clusterService,
64-
serviceComponents.settings(),
56+
startCompleted,
57+
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
6558
requestSender
6659
);
60+
61+
httpRequestSender = new HttpRequestSender(serviceComponents.threadPool(), httpClientManager, service, startCompleted);
62+
}
63+
64+
public Sender createSender() {
65+
return httpRequestSender;
6766
}
6867
}
6968

7069
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
7170

7271
private final ThreadPool threadPool;
7372
private final HttpClientManager manager;
74-
private final RequestExecutor service;
7573
private final AtomicBoolean started = new AtomicBoolean(false);
76-
private final CountDownLatch startCompleted = new CountDownLatch(1);
74+
private final RequestExecutor service;
75+
private final CountDownLatch startCompleted;
7776

7877
private HttpRequestSender(
7978
ThreadPool threadPool,
8079
HttpClientManager httpClientManager,
81-
ClusterService clusterService,
82-
Settings settings,
83-
RequestSender requestSender
80+
RequestExecutor service,
81+
CountDownLatch startCompleted
8482
) {
8583
this.threadPool = Objects.requireNonNull(threadPool);
8684
this.manager = Objects.requireNonNull(httpClientManager);
87-
service = new RequestExecutorService(
88-
threadPool,
89-
startCompleted,
90-
new RequestExecutorServiceSettings(settings, clusterService),
91-
requestSender
92-
);
85+
this.service = Objects.requireNonNull(service);
86+
this.startCompleted = Objects.requireNonNull(startCompleted);
9387
}
9488

9589
/**

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java

Lines changed: 41 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,9 @@
3535
import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
3636
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
3737
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
38+
import static org.hamcrest.Matchers.instanceOf;
3839
import static org.hamcrest.Matchers.is;
40+
import static org.hamcrest.Matchers.sameInstance;
3941
import static org.mockito.Mockito.mock;
4042

4143
public class AmazonBedrockRequestSenderTests extends ESTestCase {
@@ -58,11 +60,37 @@ public void shutdown() throws IOException, InterruptedException {
5860
terminate(threadPool);
5961
}
6062

63+
public void testCreateSender_UsesTheSameInstanceForRequestExecutor() throws Exception {
64+
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
65+
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
66+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
67+
68+
var sender1 = createSender(senderFactory);
69+
var sender2 = createSender(senderFactory);
70+
71+
assertThat(sender1, instanceOf(AmazonBedrockRequestSender.class));
72+
assertThat(sender2, instanceOf(AmazonBedrockRequestSender.class));
73+
74+
assertThat(sender1, sameInstance(sender2));
75+
}
76+
77+
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
78+
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
79+
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
80+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
81+
82+
try (var sender = createSender(senderFactory)) {
83+
sender.start();
84+
sender.start();
85+
sender.start();
86+
}
87+
}
88+
6189
public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception {
62-
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
6390
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
6491
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
65-
try (var sender = createSender(senderFactory, requestSender)) {
92+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
93+
try (var sender = createSender(senderFactory)) {
6694
sender.start();
6795

6896
var model = AmazonBedrockEmbeddingsModelTests.createModel(
@@ -90,10 +118,10 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws
90118
}
91119

92120
public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception {
93-
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
94121
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
95122
requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text"));
96-
try (var sender = createSender(senderFactory, requestSender)) {
123+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
124+
try (var sender = createSender(senderFactory)) {
97125
sender.start();
98126

99127
var model = AmazonBedrockChatCompletionModelTests.createModel(
@@ -114,14 +142,19 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws
114142
}
115143
}
116144

117-
public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) {
145+
public static AmazonBedrockRequestSender.Factory createSenderFactory(
146+
ThreadPool threadPool,
147+
Settings settings,
148+
AmazonBedrockMockExecuteRequestSender requestSender
149+
) {
118150
return new AmazonBedrockRequestSender.Factory(
119151
ServiceComponentsTests.createWithSettings(threadPool, settings),
120-
mockClusterServiceEmpty()
152+
mockClusterServiceEmpty(),
153+
requestSender
121154
);
122155
}
123156

124-
public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) {
125-
return factory.createSender(requestSender);
157+
public static Sender createSender(AmazonBedrockRequestSender.Factory factory) {
158+
return factory.createSender();
126159
}
127160
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
import static org.hamcrest.Matchers.hasSize;
4646
import static org.hamcrest.Matchers.instanceOf;
4747
import static org.hamcrest.Matchers.is;
48+
import static org.hamcrest.Matchers.sameInstance;
4849
import static org.mockito.ArgumentMatchers.anyString;
4950
import static org.mockito.Mockito.any;
5051
import static org.mockito.Mockito.doAnswer;
@@ -77,6 +78,27 @@ public void shutdown() throws IOException, InterruptedException {
7778
webServer.close();
7879
}
7980

81+
public void testCreateSender_ReturnsSameRequestExecutorInstance() {
82+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
83+
84+
var sender1 = createSender(senderFactory);
85+
var sender2 = createSender(senderFactory);
86+
87+
assertThat(sender1, instanceOf(HttpRequestSender.class));
88+
assertThat(sender2, instanceOf(HttpRequestSender.class));
89+
assertThat(sender1, sameInstance(sender2));
90+
}
91+
92+
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
93+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
94+
95+
try (var sender = createSender(senderFactory)) {
96+
sender.start();
97+
sender.start();
98+
sender.start();
99+
}
100+
}
101+
80102
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
81103
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
82104

0 commit comments

Comments
 (0)