Skip to content

Commit 6a986d8

Browse files
Using a setonce and adding a test
1 parent 6623daa commit 6a986d8

File tree

9 files changed

+82
-28
lines changed

9 files changed

+82
-28
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.util.Map;
3535
import java.util.concurrent.CountDownLatch;
3636
import java.util.concurrent.TimeUnit;
37+
import java.util.function.Consumer;
3738

3839
import static org.hamcrest.Matchers.anyOf;
3940
import static org.hamcrest.Matchers.equalTo;
@@ -341,31 +342,44 @@ protected Map<String, Object> infer(String modelId, List<String> input) throws I
341342
return inferInternal(endpoint, input, null, Map.of());
342343
}
343344

344-
protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
345+
protected Deque<ServerSentEvent> streamInferOnMockService(
346+
String modelId,
347+
TaskType taskType,
348+
List<String> input,
349+
@Nullable Consumer<Response> responseConsumerCallback
350+
) throws Exception {
345351
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
346-
return callAsync(endpoint, input);
352+
return callAsync(endpoint, input, responseConsumerCallback);
347353
}
348354

349-
protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List<String> input)
350-
throws Exception {
355+
protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
356+
String modelId,
357+
TaskType taskType,
358+
List<String> input,
359+
@Nullable Consumer<Response> responseConsumerCallback
360+
) throws Exception {
351361
var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId);
352-
return callAsyncUnified(endpoint, input, "user");
362+
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
353363
}
354364

355-
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
365+
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input, @Nullable Consumer<Response> responseConsumerCallback)
366+
throws Exception {
356367
var request = new Request("POST", endpoint);
357368
request.setJsonEntity(jsonBody(input, null));
358369

359-
return execAsyncCall(request);
370+
return execAsyncCall(request, responseConsumerCallback);
360371
}
361372

362-
private Deque<ServerSentEvent> execAsyncCall(Request request) throws Exception {
373+
private Deque<ServerSentEvent> execAsyncCall(Request request, @Nullable Consumer<Response> responseConsumerCallback) throws Exception {
363374
var responseConsumer = new AsyncInferenceResponseConsumer();
364375
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
365376
var latch = new CountDownLatch(1);
366377
client().performRequestAsync(request, new ResponseListener() {
367378
@Override
368379
public void onSuccess(Response response) {
380+
if (responseConsumerCallback != null) {
381+
responseConsumerCallback.accept(response);
382+
}
369383
latch.countDown();
370384
}
371385

@@ -378,11 +392,16 @@ public void onFailure(Exception exception) {
378392
return responseConsumer.events();
379393
}
380394

381-
private Deque<ServerSentEvent> callAsyncUnified(String endpoint, List<String> input, String role) throws Exception {
395+
private Deque<ServerSentEvent> callAsyncUnified(
396+
String endpoint,
397+
List<String> input,
398+
String role,
399+
@Nullable Consumer<Response> responseConsumerCallback
400+
) throws Exception {
382401
var request = new Request("POST", endpoint);
383402

384403
request.setJsonEntity(createUnifiedJsonBody(input, role));
385-
return execAsyncCall(request);
404+
return execAsyncCall(request, responseConsumerCallback);
386405
}
387406

388407
private String createUnifiedJsonBody(List<String> input, String role) throws IOException {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.xpack.inference;
1111

1212
import org.apache.http.util.EntityUtils;
13+
import org.elasticsearch.client.Response;
1314
import org.elasticsearch.client.ResponseException;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.settings.Settings;
@@ -28,6 +29,7 @@
2829
import java.util.Map;
2930
import java.util.Objects;
3031
import java.util.Set;
32+
import java.util.function.Consumer;
3133
import java.util.function.Function;
3234
import java.util.stream.IntStream;
3335
import java.util.stream.Stream;
@@ -37,9 +39,15 @@
3739
import static org.hamcrest.Matchers.equalTo;
3840
import static org.hamcrest.Matchers.equalToIgnoringCase;
3941
import static org.hamcrest.Matchers.hasSize;
42+
import static org.hamcrest.Matchers.is;
4043

4144
public class InferenceCrudIT extends InferenceBaseRestTest {
4245

46+
private static final Consumer<Response> VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat(
47+
r.getHeader("X-elastic-product"),
48+
is("Elasticsearch")
49+
);
50+
4351
@SuppressWarnings("unchecked")
4452
public void testCRUD() throws IOException {
4553
for (int i = 0; i < 5; i++) {
@@ -442,7 +450,7 @@ public void testUnsupportedStream() throws Exception {
442450
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
443451

444452
try {
445-
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()));
453+
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null);
446454
assertThat(events.size(), equalTo(2));
447455
events.forEach(event -> {
448456
switch (event.name()) {
@@ -469,7 +477,7 @@ public void testSupportedStream() throws Exception {
469477

470478
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
471479
try {
472-
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input);
480+
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
473481

474482
var expectedResponses = Stream.concat(
475483
input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),
@@ -496,7 +504,7 @@ public void testUnifiedCompletionInference() throws Exception {
496504

497505
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
498506
try {
499-
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input);
507+
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
500508
var expectedResponses = expectedResultsIterator(input);
501509
assertThat(events.size(), equalTo((input.size() + 1) * 2));
502510
events.forEach(event -> {

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
133133
var publisher = new RandomPublisher(requestCount, withError);
134134
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
135135
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
136-
new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse);
136+
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
137137
}
138138
}, new RestHandler() {
139139
@Override
@@ -143,7 +143,7 @@ public List<Route> routes() {
143143

144144
@Override
145145
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
146-
new ServerSentEventsRestActionListener(channel, threadPool.get()).onFailure(expectedException);
146+
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
147147
}
148148
}, new RestHandler() {
149149
@Override
@@ -154,7 +154,7 @@ public List<Route> routes() {
154154
@Override
155155
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
156156
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
157-
new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse);
157+
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
158158
}
159159
});
160160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.search.rank.RankDoc;
4444
import org.elasticsearch.threadpool.ExecutorBuilder;
4545
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
46+
import org.elasticsearch.threadpool.ThreadPool;
4647
import org.elasticsearch.xcontent.ParseField;
4748
import org.elasticsearch.xpack.core.ClientHelper;
4849
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
@@ -154,6 +155,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
154155
private final SetOnce<HttpRequestSender.Factory> httpFactory = new SetOnce<>();
155156
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
156157
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
158+
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
159+
// not being initialized yet
160+
private final SetOnce<ThreadPool> threadPoolSetOnce = new SetOnce<>();
157161
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
158162
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
159163
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
@@ -201,7 +205,7 @@ public List<RestHandler> getRestHandlers(
201205

202206
var availableRestActions = List.of(
203207
new RestInferenceAction(),
204-
new RestStreamInferenceAction(serviceComponents.get().threadPool()),
208+
new RestStreamInferenceAction(threadPoolSetOnce),
205209
new RestGetInferenceModelAction(),
206210
new RestPutInferenceModelAction(),
207211
new RestUpdateInferenceModelAction(),
@@ -210,7 +214,7 @@ public List<RestHandler> getRestHandlers(
210214
new RestGetInferenceServicesAction()
211215
);
212216
List<RestHandler> conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled()
213-
? List.of(new RestUnifiedCompletionInferenceAction(serviceComponents.get().threadPool()))
217+
? List.of(new RestUnifiedCompletionInferenceAction(threadPoolSetOnce))
214218
: List.of();
215219

216220
return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList();
@@ -221,6 +225,7 @@ public Collection<?> createComponents(PluginServices services) {
221225
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
222226
var truncator = new Truncator(settings, services.clusterService());
223227
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
228+
threadPoolSetOnce.set(services.threadPool());
224229

225230
var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager);
226231
var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10+
import org.apache.lucene.util.SetOnce;
1011
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.rest.RestChannel;
1213
import org.elasticsearch.rest.Scope;
@@ -23,9 +24,9 @@
2324

2425
@ServerlessScope(Scope.PUBLIC)
2526
public class RestStreamInferenceAction extends BaseInferenceAction {
26-
private final ThreadPool threadPool;
27+
private final SetOnce<ThreadPool> threadPool;
2728

28-
public RestStreamInferenceAction(ThreadPool threadPool) {
29+
public RestStreamInferenceAction(SetOnce<ThreadPool> threadPool) {
2930
super();
3031
this.threadPool = Objects.requireNonNull(threadPool);
3132
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10+
import org.apache.lucene.util.SetOnce;
1011
import org.elasticsearch.client.internal.node.NodeClient;
1112
import org.elasticsearch.rest.BaseRestHandler;
1213
import org.elasticsearch.rest.RestRequest;
@@ -25,9 +26,9 @@
2526

2627
@ServerlessScope(Scope.PUBLIC)
2728
public class RestUnifiedCompletionInferenceAction extends BaseRestHandler {
28-
private final ThreadPool threadPool;
29+
private final SetOnce<ThreadPool> threadPool;
2930

30-
public RestUnifiedCompletionInferenceAction(ThreadPool threadPool) {
31+
public RestUnifiedCompletionInferenceAction(SetOnce<ThreadPool> threadPool) {
3132
super();
3233
this.threadPool = Objects.requireNonNull(threadPool);
3334
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.BytesRef;
13+
import org.apache.lucene.util.SetOnce;
1314
import org.elasticsearch.ElasticsearchException;
1415
import org.elasticsearch.ExceptionsHelper;
1516
import org.elasticsearch.action.ActionListener;
@@ -58,7 +59,7 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
5859
private final AtomicBoolean isLastPart = new AtomicBoolean(false);
5960
private final RestChannel channel;
6061
private final ToXContent.Params params;
61-
private final ThreadPool threadPool;
62+
private final SetOnce<ThreadPool> threadPool;
6263

6364
/**
6465
* A listener for the first part of the next entry to become available for transmission.
@@ -70,11 +71,11 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
7071
*/
7172
private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener;
7273

73-
public ServerSentEventsRestActionListener(RestChannel channel, ThreadPool threadPool) {
74+
public ServerSentEventsRestActionListener(RestChannel channel, SetOnce<ThreadPool> threadPool) {
7475
this(channel, channel.request(), threadPool);
7576
}
7677

77-
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, ThreadPool threadPool) {
78+
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce<ThreadPool> threadPool) {
7879
this.channel = channel;
7980
this.params = params;
8081
this.threadPool = Objects.requireNonNull(threadPool);
@@ -123,7 +124,7 @@ private void initializeStream(InferenceAction.Response response) {
123124

124125
nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext(
125126
chunkedResponseBodyActionListener,
126-
threadPool.getThreadContext()
127+
threadPool.get().getThreadContext()
127128
);
128129

129130
// subscribe will call onSubscribe, which requests the first chunk

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import org.elasticsearch.test.rest.FakeRestRequest;
1414
import org.elasticsearch.test.rest.RestActionTestCase;
1515
import org.elasticsearch.threadpool.TestThreadPool;
16+
import org.elasticsearch.threadpool.ThreadPool;
1617
import org.elasticsearch.xcontent.XContentType;
1718
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
19+
import org.junit.After;
1820
import org.junit.Before;
1921

2022
import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
@@ -23,10 +25,18 @@
2325
import static org.hamcrest.Matchers.instanceOf;
2426

2527
public class RestStreamInferenceActionTests extends RestActionTestCase {
28+
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
2629

2730
@Before
2831
public void setUpAction() {
29-
controller().registerHandler(new RestStreamInferenceAction(new TestThreadPool(getTestName())));
32+
threadPool.set(new TestThreadPool(getTestName()));
33+
controller().registerHandler(new RestStreamInferenceAction(threadPool));
34+
}
35+
36+
@After
37+
public void tearDownAction() {
38+
terminate(threadPool.get());
39+
3040
}
3141

3242
public void testStreamIsTrue() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import org.elasticsearch.test.rest.FakeRestRequest;
1919
import org.elasticsearch.test.rest.RestActionTestCase;
2020
import org.elasticsearch.threadpool.TestThreadPool;
21+
import org.elasticsearch.threadpool.ThreadPool;
2122
import org.elasticsearch.xcontent.XContentType;
2223
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
24+
import org.junit.After;
2325
import org.junit.Before;
2426

2527
import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
@@ -28,10 +30,17 @@
2830
import static org.hamcrest.Matchers.instanceOf;
2931

3032
public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase {
33+
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
3134

3235
@Before
3336
public void setUpAction() {
34-
controller().registerHandler(new RestUnifiedCompletionInferenceAction(new TestThreadPool(getTestName())));
37+
threadPool.set(new TestThreadPool(getTestName()));
38+
controller().registerHandler(new RestUnifiedCompletionInferenceAction(threadPool));
39+
}
40+
41+
@After
42+
public void tearDownAction() {
43+
terminate(threadPool.get());
3544
}
3645

3746
public void testStreamIsTrue() {

0 commit comments

Comments
 (0)