Skip to content

Commit d35d306

Browse files
[ML] Ensuring only a single request executor object is created (#133424) (#133716)
* Ensuring only a single request executor thread is started
* Reverting test changes
* Update docs/changelog/133424.yaml

(cherry picked from commit d5a9343)

# Conflicts:
#   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java
1 parent dc2a93a commit d35d306

File tree

5 files changed: +134 additions, -63 deletions (lines changed)

docs/changelog/133424.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 133424
2+
summary: Ensuring only a single request executor object is created
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java

Lines changed: 37 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.ElasticsearchException;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.cluster.service.ClusterService;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.core.TimeValue;
1615
import org.elasticsearch.inference.InferenceServiceResults;
1716
import org.elasticsearch.threadpool.ThreadPool;
@@ -38,31 +37,46 @@ public class AmazonBedrockRequestSender implements Sender {
3837

3938
public static class Factory {
4039
private final ServiceComponents serviceComponents;
41-
private final ClusterService clusterService;
40+
private final AmazonBedrockRequestExecutorService executorService;
41+
private final CountDownLatch startCompleted = new CountDownLatch(1);
42+
private final AmazonBedrockRequestSender bedrockRequestSender;
4243

4344
public Factory(ServiceComponents serviceComponents, ClusterService clusterService) {
44-
this.serviceComponents = Objects.requireNonNull(serviceComponents);
45-
this.clusterService = Objects.requireNonNull(clusterService);
46-
}
47-
48-
public Sender createSender() {
49-
var clientCache = new AmazonBedrockInferenceClientCache(
50-
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
51-
Clock.systemUTC()
45+
this(
46+
serviceComponents,
47+
clusterService,
48+
new AmazonBedrockExecuteOnlyRequestSender(
49+
new AmazonBedrockInferenceClientCache(
50+
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
51+
Clock.systemUTC()
52+
),
53+
serviceComponents.throttlerManager()
54+
)
5255
);
53-
return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
5456
}
5557

56-
Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
57-
var sender = new AmazonBedrockRequestSender(
58+
public Factory(
59+
ServiceComponents serviceComponents,
60+
ClusterService clusterService,
61+
AmazonBedrockExecuteOnlyRequestSender requestSender
62+
) {
63+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
64+
Objects.requireNonNull(clusterService);
65+
66+
executorService = new AmazonBedrockRequestExecutorService(
5867
serviceComponents.threadPool(),
59-
clusterService,
60-
serviceComponents.settings(),
61-
Objects.requireNonNull(requestSender)
68+
startCompleted,
69+
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
70+
requestSender
6271
);
72+
73+
bedrockRequestSender = new AmazonBedrockRequestSender(serviceComponents.threadPool(), executorService, startCompleted);
74+
}
75+
76+
public Sender createSender() {
6377
// ensure this is started
64-
sender.start();
65-
return sender;
78+
bedrockRequestSender.start();
79+
return bedrockRequestSender;
6680
}
6781
}
6882

@@ -71,21 +85,16 @@ Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
7185
private final ThreadPool threadPool;
7286
private final AmazonBedrockRequestExecutorService executorService;
7387
private final AtomicBoolean started = new AtomicBoolean(false);
74-
private final CountDownLatch startCompleted = new CountDownLatch(1);
88+
private final CountDownLatch startCompleted;
7589

7690
protected AmazonBedrockRequestSender(
7791
ThreadPool threadPool,
78-
ClusterService clusterService,
79-
Settings settings,
80-
AmazonBedrockExecuteOnlyRequestSender requestSender
92+
AmazonBedrockRequestExecutorService executorService,
93+
CountDownLatch startCompleted
8194
) {
8295
this.threadPool = Objects.requireNonNull(threadPool);
83-
executorService = new AmazonBedrockRequestExecutorService(
84-
threadPool,
85-
startCompleted,
86-
new RequestExecutorServiceSettings(settings, clusterService),
87-
requestSender
88-
);
96+
this.executorService = Objects.requireNonNull(executorService);
97+
this.startCompleted = Objects.requireNonNull(startCompleted);
8998
}
9099

91100
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

Lines changed: 29 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.ContextPreservingActionListener;
1313
import org.elasticsearch.cluster.service.ClusterService;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.core.Nullable;
1615
import org.elasticsearch.core.TimeValue;
1716
import org.elasticsearch.inference.InferenceServiceResults;
@@ -42,60 +41,63 @@ public class HttpRequestSender implements Sender {
4241
* A helper class for constructing a {@link HttpRequestSender}.
4342
*/
4443
public static class Factory {
45-
private final ServiceComponents serviceComponents;
46-
private final HttpClientManager httpClientManager;
47-
private final ClusterService clusterService;
48-
private final RequestSender requestSender;
44+
private final HttpRequestSender httpRequestSender;
4945

5046
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
51-
this.serviceComponents = Objects.requireNonNull(serviceComponents);
52-
this.httpClientManager = Objects.requireNonNull(httpClientManager);
53-
this.clusterService = Objects.requireNonNull(clusterService);
47+
Objects.requireNonNull(serviceComponents);
48+
Objects.requireNonNull(clusterService);
49+
Objects.requireNonNull(httpClientManager);
5450

55-
requestSender = new RetryingHttpSender(
56-
this.httpClientManager.getHttpClient(),
51+
var requestSender = new RetryingHttpSender(
52+
httpClientManager.getHttpClient(),
5753
serviceComponents.throttlerManager(),
5854
new RetrySettings(serviceComponents.settings(), clusterService),
5955
serviceComponents.threadPool()
6056
);
61-
}
6257

63-
public Sender createSender() {
64-
return new HttpRequestSender(
58+
var startCompleted = new CountDownLatch(1);
59+
var service = new RequestExecutorService(
6560
serviceComponents.threadPool(),
66-
httpClientManager,
67-
clusterService,
68-
serviceComponents.settings(),
61+
startCompleted,
62+
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
6963
requestSender
7064
);
65+
66+
httpRequestSender = new HttpRequestSender(
67+
serviceComponents.threadPool(),
68+
httpClientManager,
69+
requestSender,
70+
service,
71+
startCompleted
72+
);
73+
}
74+
75+
public Sender createSender() {
76+
return httpRequestSender;
7177
}
7278
}
7379

7480
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
7581

7682
private final ThreadPool threadPool;
7783
private final HttpClientManager manager;
78-
private final RequestExecutor service;
7984
private final AtomicBoolean started = new AtomicBoolean(false);
80-
private final CountDownLatch startCompleted = new CountDownLatch(1);
8185
private final RequestSender requestSender;
86+
private final RequestExecutor service;
87+
private final CountDownLatch startCompleted;
8288

8389
private HttpRequestSender(
8490
ThreadPool threadPool,
8591
HttpClientManager httpClientManager,
86-
ClusterService clusterService,
87-
Settings settings,
88-
RequestSender requestSender
92+
RequestSender requestSender,
93+
RequestExecutor service,
94+
CountDownLatch startCompleted
8995
) {
9096
this.threadPool = Objects.requireNonNull(threadPool);
9197
this.manager = Objects.requireNonNull(httpClientManager);
9298
this.requestSender = Objects.requireNonNull(requestSender);
93-
service = new RequestExecutorService(
94-
threadPool,
95-
startCompleted,
96-
new RequestExecutorServiceSettings(settings, clusterService),
97-
requestSender
98-
);
99+
this.service = Objects.requireNonNull(service);
100+
this.startCompleted = Objects.requireNonNull(startCompleted);
99101
}
100102

101103
/**

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java

Lines changed: 41 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,9 @@
3636
import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
3737
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
3838
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
39+
import static org.hamcrest.Matchers.instanceOf;
3940
import static org.hamcrest.Matchers.is;
41+
import static org.hamcrest.Matchers.sameInstance;
4042
import static org.mockito.Mockito.mock;
4143

4244
public class AmazonBedrockRequestSenderTests extends ESTestCase {
@@ -59,11 +61,37 @@ public void shutdown() throws IOException, InterruptedException {
5961
terminate(threadPool);
6062
}
6163

64+
public void testCreateSender_UsesTheSameInstanceForRequestExecutor() throws Exception {
65+
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
66+
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
67+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
68+
69+
var sender1 = createSender(senderFactory);
70+
var sender2 = createSender(senderFactory);
71+
72+
assertThat(sender1, instanceOf(AmazonBedrockRequestSender.class));
73+
assertThat(sender2, instanceOf(AmazonBedrockRequestSender.class));
74+
75+
assertThat(sender1, sameInstance(sender2));
76+
}
77+
78+
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
79+
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
80+
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
81+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
82+
83+
try (var sender = createSender(senderFactory)) {
84+
sender.start();
85+
sender.start();
86+
sender.start();
87+
}
88+
}
89+
6290
public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception {
63-
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
6491
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
6592
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
66-
try (var sender = createSender(senderFactory, requestSender)) {
93+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
94+
try (var sender = createSender(senderFactory)) {
6795
sender.start();
6896

6997
var model = AmazonBedrockEmbeddingsModelTests.createModel(
@@ -91,10 +119,10 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws
91119
}
92120

93121
public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception {
94-
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
95122
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
96123
requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text"));
97-
try (var sender = createSender(senderFactory, requestSender)) {
124+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
125+
try (var sender = createSender(senderFactory)) {
98126
sender.start();
99127

100128
var model = AmazonBedrockChatCompletionModelTests.createModel(
@@ -115,14 +143,19 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws
115143
}
116144
}
117145

118-
public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) {
146+
public static AmazonBedrockRequestSender.Factory createSenderFactory(
147+
ThreadPool threadPool,
148+
Settings settings,
149+
AmazonBedrockMockExecuteRequestSender requestSender
150+
) {
119151
return new AmazonBedrockRequestSender.Factory(
120152
ServiceComponentsTests.createWithSettings(threadPool, settings),
121-
mockClusterServiceEmpty()
153+
mockClusterServiceEmpty(),
154+
requestSender
122155
);
123156
}
124157

125-
public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) {
126-
return factory.createSender(requestSender);
158+
public static Sender createSender(AmazonBedrockRequestSender.Factory factory) {
159+
return factory.createSender();
127160
}
128161
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
import static org.hamcrest.Matchers.hasSize;
5757
import static org.hamcrest.Matchers.instanceOf;
5858
import static org.hamcrest.Matchers.is;
59+
import static org.hamcrest.Matchers.sameInstance;
5960
import static org.mockito.ArgumentMatchers.anyString;
6061
import static org.mockito.Mockito.any;
6162
import static org.mockito.Mockito.doAnswer;
@@ -88,6 +89,27 @@ public void shutdown() throws IOException, InterruptedException {
8889
webServer.close();
8990
}
9091

92+
public void testCreateSender_ReturnsSameRequestExecutorInstance() {
93+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
94+
95+
var sender1 = createSender(senderFactory);
96+
var sender2 = createSender(senderFactory);
97+
98+
assertThat(sender1, instanceOf(HttpRequestSender.class));
99+
assertThat(sender2, instanceOf(HttpRequestSender.class));
100+
assertThat(sender1, sameInstance(sender2));
101+
}
102+
103+
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
104+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
105+
106+
try (var sender = createSender(senderFactory)) {
107+
sender.start();
108+
sender.start();
109+
sender.start();
110+
}
111+
}
112+
91113
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
92114
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
93115

0 commit comments

Comments
 (0)