Skip to content

Commit e630b76

Browse files
jonathan-buttnersarog
authored andcommitted
[ML] Ensuring only a single request executor object is created (elastic#133424) (elastic#133670)
* Ensuring only a single request executor thread is started * Reverting test changes * Update docs/changelog/133424.yaml
1 parent bf85f56 commit e630b76

File tree

5 files changed

+134
-63
lines changed

5 files changed

+134
-63
lines changed

docs/changelog/133424.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 133424
2+
summary: Ensuring only a single request executor object is created
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.ContextPreservingActionListener;
1313
import org.elasticsearch.cluster.service.ClusterService;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.core.Nullable;
1615
import org.elasticsearch.core.TimeValue;
1716
import org.elasticsearch.inference.InferenceServiceResults;
@@ -42,60 +41,63 @@ public class HttpRequestSender implements Sender {
4241
* A helper class for constructing a {@link HttpRequestSender}.
4342
*/
4443
public static class Factory {
45-
private final ServiceComponents serviceComponents;
46-
private final HttpClientManager httpClientManager;
47-
private final ClusterService clusterService;
48-
private final RequestSender requestSender;
44+
private final HttpRequestSender httpRequestSender;
4945

5046
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
51-
this.serviceComponents = Objects.requireNonNull(serviceComponents);
52-
this.httpClientManager = Objects.requireNonNull(httpClientManager);
53-
this.clusterService = Objects.requireNonNull(clusterService);
47+
Objects.requireNonNull(serviceComponents);
48+
Objects.requireNonNull(clusterService);
49+
Objects.requireNonNull(httpClientManager);
5450

55-
requestSender = new RetryingHttpSender(
56-
this.httpClientManager.getHttpClient(),
51+
var requestSender = new RetryingHttpSender(
52+
httpClientManager.getHttpClient(),
5753
serviceComponents.throttlerManager(),
5854
new RetrySettings(serviceComponents.settings(), clusterService),
5955
serviceComponents.threadPool()
6056
);
61-
}
6257

63-
public Sender createSender() {
64-
return new HttpRequestSender(
58+
var startCompleted = new CountDownLatch(1);
59+
var service = new RequestExecutorService(
6560
serviceComponents.threadPool(),
66-
httpClientManager,
67-
clusterService,
68-
serviceComponents.settings(),
61+
startCompleted,
62+
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
6963
requestSender
7064
);
65+
66+
httpRequestSender = new HttpRequestSender(
67+
serviceComponents.threadPool(),
68+
httpClientManager,
69+
requestSender,
70+
service,
71+
startCompleted
72+
);
73+
}
74+
75+
public Sender createSender() {
76+
return httpRequestSender;
7177
}
7278
}
7379

7480
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
7581

7682
private final ThreadPool threadPool;
7783
private final HttpClientManager manager;
78-
private final RequestExecutor service;
7984
private final AtomicBoolean started = new AtomicBoolean(false);
80-
private final CountDownLatch startCompleted = new CountDownLatch(1);
8185
private final RequestSender requestSender;
86+
private final RequestExecutor service;
87+
private final CountDownLatch startCompleted;
8288

8389
private HttpRequestSender(
8490
ThreadPool threadPool,
8591
HttpClientManager httpClientManager,
86-
ClusterService clusterService,
87-
Settings settings,
88-
RequestSender requestSender
92+
RequestSender requestSender,
93+
RequestExecutor service,
94+
CountDownLatch startCompleted
8995
) {
9096
this.threadPool = Objects.requireNonNull(threadPool);
9197
this.manager = Objects.requireNonNull(httpClientManager);
9298
this.requestSender = Objects.requireNonNull(requestSender);
93-
service = new RequestExecutorService(
94-
threadPool,
95-
startCompleted,
96-
new RequestExecutorServiceSettings(settings, clusterService),
97-
requestSender
98-
);
99+
this.service = Objects.requireNonNull(service);
100+
this.startCompleted = Objects.requireNonNull(startCompleted);
99101
}
100102

101103
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.ElasticsearchException;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.cluster.service.ClusterService;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.core.TimeValue;
1615
import org.elasticsearch.inference.InferenceServiceResults;
1716
import org.elasticsearch.threadpool.ThreadPool;
@@ -38,31 +37,46 @@ public class AmazonBedrockRequestSender implements Sender {
3837

3938
public static class Factory {
4039
private final ServiceComponents serviceComponents;
41-
private final ClusterService clusterService;
40+
private final AmazonBedrockRequestExecutorService executorService;
41+
private final CountDownLatch startCompleted = new CountDownLatch(1);
42+
private final AmazonBedrockRequestSender bedrockRequestSender;
4243

4344
public Factory(ServiceComponents serviceComponents, ClusterService clusterService) {
44-
this.serviceComponents = Objects.requireNonNull(serviceComponents);
45-
this.clusterService = Objects.requireNonNull(clusterService);
46-
}
47-
48-
public Sender createSender() {
49-
var clientCache = new AmazonBedrockInferenceClientCache(
50-
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
51-
Clock.systemUTC()
45+
this(
46+
serviceComponents,
47+
clusterService,
48+
new AmazonBedrockExecuteOnlyRequestSender(
49+
new AmazonBedrockInferenceClientCache(
50+
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
51+
Clock.systemUTC()
52+
),
53+
serviceComponents.throttlerManager()
54+
)
5255
);
53-
return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
5456
}
5557

56-
Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
57-
var sender = new AmazonBedrockRequestSender(
58+
public Factory(
59+
ServiceComponents serviceComponents,
60+
ClusterService clusterService,
61+
AmazonBedrockExecuteOnlyRequestSender requestSender
62+
) {
63+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
64+
Objects.requireNonNull(clusterService);
65+
66+
executorService = new AmazonBedrockRequestExecutorService(
5867
serviceComponents.threadPool(),
59-
clusterService,
60-
serviceComponents.settings(),
61-
Objects.requireNonNull(requestSender)
68+
startCompleted,
69+
new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
70+
requestSender
6271
);
72+
73+
bedrockRequestSender = new AmazonBedrockRequestSender(serviceComponents.threadPool(), executorService, startCompleted);
74+
}
75+
76+
public Sender createSender() {
6377
// ensure this is started
64-
sender.start();
65-
return sender;
78+
bedrockRequestSender.start();
79+
return bedrockRequestSender;
6680
}
6781
}
6882

@@ -71,21 +85,16 @@ Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
7185
private final ThreadPool threadPool;
7286
private final AmazonBedrockRequestExecutorService executorService;
7387
private final AtomicBoolean started = new AtomicBoolean(false);
74-
private final CountDownLatch startCompleted = new CountDownLatch(1);
88+
private final CountDownLatch startCompleted;
7589

7690
protected AmazonBedrockRequestSender(
7791
ThreadPool threadPool,
78-
ClusterService clusterService,
79-
Settings settings,
80-
AmazonBedrockExecuteOnlyRequestSender requestSender
92+
AmazonBedrockRequestExecutorService executorService,
93+
CountDownLatch startCompleted
8194
) {
8295
this.threadPool = Objects.requireNonNull(threadPool);
83-
executorService = new AmazonBedrockRequestExecutorService(
84-
threadPool,
85-
startCompleted,
86-
new RequestExecutorServiceSettings(settings, clusterService),
87-
requestSender
88-
);
96+
this.executorService = Objects.requireNonNull(executorService);
97+
this.startCompleted = Objects.requireNonNull(startCompleted);
8998
}
9099

91100
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import static org.hamcrest.Matchers.hasSize;
5959
import static org.hamcrest.Matchers.instanceOf;
6060
import static org.hamcrest.Matchers.is;
61+
import static org.hamcrest.Matchers.sameInstance;
6162
import static org.mockito.ArgumentMatchers.anyString;
6263
import static org.mockito.Mockito.any;
6364
import static org.mockito.Mockito.doAnswer;
@@ -90,6 +91,27 @@ public void shutdown() throws IOException, InterruptedException {
9091
webServer.close();
9192
}
9293

94+
public void testCreateSender_ReturnsSameRequestExecutorInstance() {
95+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
96+
97+
var sender1 = createSender(senderFactory);
98+
var sender2 = createSender(senderFactory);
99+
100+
assertThat(sender1, instanceOf(HttpRequestSender.class));
101+
assertThat(sender2, instanceOf(HttpRequestSender.class));
102+
assertThat(sender1, sameInstance(sender2));
103+
}
104+
105+
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
106+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
107+
108+
try (var sender = createSender(senderFactory)) {
109+
sender.start();
110+
sender.start();
111+
sender.start();
112+
}
113+
}
114+
93115
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
94116
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
95117

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3838
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
3939
import static org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
40+
import static org.hamcrest.Matchers.instanceOf;
4041
import static org.hamcrest.Matchers.is;
42+
import static org.hamcrest.Matchers.sameInstance;
4143
import static org.mockito.Mockito.mock;
4244

4345
public class AmazonBedrockRequestSenderTests extends ESTestCase {
@@ -60,11 +62,37 @@ public void shutdown() throws IOException, InterruptedException {
6062
terminate(threadPool);
6163
}
6264

65+
public void testCreateSender_UsesTheSameInstanceForRequestExecutor() throws Exception {
66+
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
67+
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
68+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
69+
70+
var sender1 = createSender(senderFactory);
71+
var sender2 = createSender(senderFactory);
72+
73+
assertThat(sender1, instanceOf(AmazonBedrockRequestSender.class));
74+
assertThat(sender2, instanceOf(AmazonBedrockRequestSender.class));
75+
76+
assertThat(sender1, sameInstance(sender2));
77+
}
78+
79+
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
80+
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
81+
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
82+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
83+
84+
try (var sender = createSender(senderFactory)) {
85+
sender.start();
86+
sender.start();
87+
sender.start();
88+
}
89+
}
90+
6391
public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception {
64-
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
6592
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
6693
requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
67-
try (var sender = createSender(senderFactory, requestSender)) {
94+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
95+
try (var sender = createSender(senderFactory)) {
6896
sender.start();
6997

7098
var model = AmazonBedrockEmbeddingsModelTests.createModel(
@@ -92,10 +120,10 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws
92120
}
93121

94122
public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception {
95-
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
96123
var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
97124
requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text"));
98-
try (var sender = createSender(senderFactory, requestSender)) {
125+
var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
126+
try (var sender = createSender(senderFactory)) {
99127
sender.start();
100128

101129
var model = AmazonBedrockChatCompletionModelTests.createModel(
@@ -116,14 +144,19 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws
116144
}
117145
}
118146

119-
public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) {
147+
public static AmazonBedrockRequestSender.Factory createSenderFactory(
148+
ThreadPool threadPool,
149+
Settings settings,
150+
AmazonBedrockMockExecuteRequestSender requestSender
151+
) {
120152
return new AmazonBedrockRequestSender.Factory(
121153
ServiceComponentsTests.createWithSettings(threadPool, settings),
122-
mockClusterServiceEmpty()
154+
mockClusterServiceEmpty(),
155+
requestSender
123156
);
124157
}
125158

126-
public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) {
127-
return factory.createSender(requestSender);
159+
public static Sender createSender(AmazonBedrockRequestSender.Factory factory) {
160+
return factory.createSender();
128161
}
129162
}

0 commit comments

Comments
 (0)