Skip to content

Commit 12f4357

Browse files
[8.x] [ML] Fix loss of context in the inference API for streaming APIs (elastic#118999) (elastic#119218)
* [ML] Fix loss of context in the inference API for streaming APIs (elastic#118999)
* Adding context preserving fix
* Update docs/changelog/118999.yaml
* Update docs/changelog/118999.yaml
* Using a setonce and adding a test
* Updating the changelog

(cherry picked from commit 7ba3cb9)

# Conflicts:
# x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

* Removing assert
1 parent 94de73a commit 12f4357

File tree

10 files changed

+132
-26
lines changed

10 files changed

+132
-26
lines changed

docs/changelog/118999.yaml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
pr: 118999
2+
summary: Fix loss of context in the inference API for streaming APIs
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 119000

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
import java.util.Map;
3535
import java.util.concurrent.CountDownLatch;
3636
import java.util.concurrent.TimeUnit;
37+
import java.util.function.Consumer;
3738

3839
import static org.hamcrest.Matchers.anyOf;
3940
import static org.hamcrest.Matchers.equalTo;
@@ -341,31 +342,44 @@ protected Map<String, Object> infer(String modelId, List<String> input) throws I
341342
return inferInternal(endpoint, input, null, Map.of());
342343
}
343344

344-
protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
345+
protected Deque<ServerSentEvent> streamInferOnMockService(
346+
String modelId,
347+
TaskType taskType,
348+
List<String> input,
349+
@Nullable Consumer<Response> responseConsumerCallback
350+
) throws Exception {
345351
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
346-
return callAsync(endpoint, input);
352+
return callAsync(endpoint, input, responseConsumerCallback);
347353
}
348354

349-
protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List<String> input)
350-
throws Exception {
355+
protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
356+
String modelId,
357+
TaskType taskType,
358+
List<String> input,
359+
@Nullable Consumer<Response> responseConsumerCallback
360+
) throws Exception {
351361
var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId);
352-
return callAsyncUnified(endpoint, input, "user");
362+
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
353363
}
354364

355-
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
365+
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input, @Nullable Consumer<Response> responseConsumerCallback)
366+
throws Exception {
356367
var request = new Request("POST", endpoint);
357368
request.setJsonEntity(jsonBody(input, null));
358369

359-
return execAsyncCall(request);
370+
return execAsyncCall(request, responseConsumerCallback);
360371
}
361372

362-
private Deque<ServerSentEvent> execAsyncCall(Request request) throws Exception {
373+
private Deque<ServerSentEvent> execAsyncCall(Request request, @Nullable Consumer<Response> responseConsumerCallback) throws Exception {
363374
var responseConsumer = new AsyncInferenceResponseConsumer();
364375
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
365376
var latch = new CountDownLatch(1);
366377
client().performRequestAsync(request, new ResponseListener() {
367378
@Override
368379
public void onSuccess(Response response) {
380+
if (responseConsumerCallback != null) {
381+
responseConsumerCallback.accept(response);
382+
}
369383
latch.countDown();
370384
}
371385

@@ -378,11 +392,16 @@ public void onFailure(Exception exception) {
378392
return responseConsumer.events();
379393
}
380394

381-
private Deque<ServerSentEvent> callAsyncUnified(String endpoint, List<String> input, String role) throws Exception {
395+
private Deque<ServerSentEvent> callAsyncUnified(
396+
String endpoint,
397+
List<String> input,
398+
String role,
399+
@Nullable Consumer<Response> responseConsumerCallback
400+
) throws Exception {
382401
var request = new Request("POST", endpoint);
383402

384403
request.setJsonEntity(createUnifiedJsonBody(input, role));
385-
return execAsyncCall(request);
404+
return execAsyncCall(request, responseConsumerCallback);
386405
}
387406

388407
private String createUnifiedJsonBody(List<String> input, String role) throws IOException {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.xpack.inference;
1111

1212
import org.apache.http.util.EntityUtils;
13+
import org.elasticsearch.client.Response;
1314
import org.elasticsearch.client.ResponseException;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.settings.Settings;
@@ -25,6 +26,7 @@
2526
import java.util.Map;
2627
import java.util.Objects;
2728
import java.util.Set;
29+
import java.util.function.Consumer;
2830
import java.util.function.Function;
2931
import java.util.stream.IntStream;
3032
import java.util.stream.Stream;
@@ -34,9 +36,15 @@
3436
import static org.hamcrest.Matchers.equalTo;
3537
import static org.hamcrest.Matchers.equalToIgnoringCase;
3638
import static org.hamcrest.Matchers.hasSize;
39+
import static org.hamcrest.Matchers.is;
3740

3841
public class InferenceCrudIT extends InferenceBaseRestTest {
3942

43+
private static final Consumer<Response> VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat(
44+
r.getHeader("X-elastic-product"),
45+
is("Elasticsearch")
46+
);
47+
4048
@SuppressWarnings("unchecked")
4149
public void testCRUD() throws IOException {
4250
for (int i = 0; i < 5; i++) {
@@ -288,7 +296,7 @@ public void testUnsupportedStream() throws Exception {
288296
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
289297

290298
try {
291-
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()));
299+
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null);
292300
assertThat(events.size(), equalTo(2));
293301
events.forEach(event -> {
294302
switch (event.name()) {
@@ -315,7 +323,7 @@ public void testSupportedStream() throws Exception {
315323

316324
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
317325
try {
318-
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input);
326+
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
319327

320328
var expectedResponses = Stream.concat(
321329
input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),
@@ -342,7 +350,7 @@ public void testUnifiedCompletionInference() throws Exception {
342350

343351
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
344352
try {
345-
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input);
353+
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
346354
var expectedResponses = expectedResultsIterator(input);
347355
assertThat(events.size(), equalTo((input.size() + 1) * 2));
348356
events.forEach(event -> {

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import org.apache.http.nio.util.SimpleInputBuffer;
1818
import org.apache.http.protocol.HttpContext;
1919
import org.apache.http.util.EntityUtils;
20+
import org.apache.lucene.util.SetOnce;
2021
import org.elasticsearch.client.Request;
2122
import org.elasticsearch.client.RequestOptions;
2223
import org.elasticsearch.client.Response;
@@ -43,6 +44,7 @@
4344
import org.elasticsearch.rest.RestHandler;
4445
import org.elasticsearch.rest.RestRequest;
4546
import org.elasticsearch.test.ESIntegTestCase;
47+
import org.elasticsearch.threadpool.ThreadPool;
4648
import org.elasticsearch.xcontent.ToXContent;
4749
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4850
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
@@ -52,6 +54,7 @@
5254
import java.io.IOException;
5355
import java.nio.charset.StandardCharsets;
5456
import java.util.Collection;
57+
import java.util.Collections;
5558
import java.util.Deque;
5659
import java.util.Iterator;
5760
import java.util.List;
@@ -96,6 +99,14 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
9699
}
97100

98101
public static class StreamingPlugin extends Plugin implements ActionPlugin {
102+
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
103+
104+
@Override
105+
public Collection<?> createComponents(PluginServices services) {
106+
threadPool.set(services.threadPool());
107+
return Collections.emptyList();
108+
}
109+
99110
@Override
100111
public Collection<RestHandler> getRestHandlers(
101112
Settings settings,
@@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
122133
var publisher = new RandomPublisher(requestCount, withError);
123134
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
124135
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
125-
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
136+
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
126137
}
127138
}, new RestHandler() {
128139
@Override
@@ -132,7 +143,7 @@ public List<Route> routes() {
132143

133144
@Override
134145
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
135-
new ServerSentEventsRestActionListener(channel).onFailure(expectedException);
146+
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
136147
}
137148
}, new RestHandler() {
138149
@Override
@@ -143,7 +154,7 @@ public List<Route> routes() {
143154
@Override
144155
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
145156
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
146-
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
157+
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
147158
}
148159
});
149160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
import org.elasticsearch.search.rank.RankDoc;
4343
import org.elasticsearch.threadpool.ExecutorBuilder;
4444
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
45+
import org.elasticsearch.threadpool.ThreadPool;
4546
import org.elasticsearch.xcontent.ParseField;
4647
import org.elasticsearch.xpack.core.ClientHelper;
4748
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
@@ -151,6 +152,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
151152
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
152153
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
153154
private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
155+
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
156+
// not being initialized yet
157+
private final SetOnce<ThreadPool> threadPoolSetOnce = new SetOnce<>();
154158
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
155159
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
156160
private List<InferenceServiceExtension> inferenceServiceExtensions;
@@ -195,15 +199,15 @@ public List<RestHandler> getRestHandlers(
195199
) {
196200
var availableRestActions = List.of(
197201
new RestInferenceAction(),
198-
new RestStreamInferenceAction(),
202+
new RestStreamInferenceAction(threadPoolSetOnce),
199203
new RestGetInferenceModelAction(),
200204
new RestPutInferenceModelAction(),
201205
new RestUpdateInferenceModelAction(),
202206
new RestDeleteInferenceEndpointAction(),
203207
new RestGetInferenceDiagnosticsAction()
204208
);
205209
List<RestHandler> conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled()
206-
? List.of(new RestUnifiedCompletionInferenceAction())
210+
? List.of(new RestUnifiedCompletionInferenceAction(threadPoolSetOnce))
207211
: List.of();
208212

209213
return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList();
@@ -214,6 +218,7 @@ public Collection<?> createComponents(PluginServices services) {
214218
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
215219
var truncator = new Truncator(settings, services.clusterService());
216220
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
221+
threadPoolSetOnce.set(services.threadPool());
217222

218223
var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager);
219224
var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,20 +7,30 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10+
import org.apache.lucene.util.SetOnce;
1011
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.rest.RestChannel;
1213
import org.elasticsearch.rest.Scope;
1314
import org.elasticsearch.rest.ServerlessScope;
15+
import org.elasticsearch.threadpool.ThreadPool;
1416
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1517

1618
import java.util.List;
19+
import java.util.Objects;
1720

1821
import static org.elasticsearch.rest.RestRequest.Method.POST;
1922
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH;
2023
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH;
2124

2225
@ServerlessScope(Scope.PUBLIC)
2326
public class RestStreamInferenceAction extends BaseInferenceAction {
27+
private final SetOnce<ThreadPool> threadPool;
28+
29+
public RestStreamInferenceAction(SetOnce<ThreadPool> threadPool) {
30+
super();
31+
this.threadPool = Objects.requireNonNull(threadPool);
32+
}
33+
2434
@Override
2535
public String getName() {
2636
return "stream_inference_action";
@@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques
3848

3949
@Override
4050
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
41-
return new ServerSentEventsRestActionListener(channel);
51+
return new ServerSentEventsRestActionListener(channel, threadPool);
4252
}
4353
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,22 +7,32 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10+
import org.apache.lucene.util.SetOnce;
1011
import org.elasticsearch.client.internal.node.NodeClient;
1112
import org.elasticsearch.rest.BaseRestHandler;
1213
import org.elasticsearch.rest.RestRequest;
1314
import org.elasticsearch.rest.Scope;
1415
import org.elasticsearch.rest.ServerlessScope;
16+
import org.elasticsearch.threadpool.ThreadPool;
1517
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
1618

1719
import java.io.IOException;
1820
import java.util.List;
21+
import java.util.Objects;
1922

2023
import static org.elasticsearch.rest.RestRequest.Method.POST;
2124
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH;
2225
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH;
2326

2427
@ServerlessScope(Scope.PUBLIC)
2528
public class RestUnifiedCompletionInferenceAction extends BaseRestHandler {
29+
private final SetOnce<ThreadPool> threadPool;
30+
31+
public RestUnifiedCompletionInferenceAction(SetOnce<ThreadPool> threadPool) {
32+
super();
33+
this.threadPool = Objects.requireNonNull(threadPool);
34+
}
35+
2636
@Override
2737
public String getName() {
2838
return "unified_inference_action";
@@ -44,6 +54,10 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4454
request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser);
4555
}
4656

47-
return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel));
57+
return channel -> client.execute(
58+
UnifiedCompletionAction.INSTANCE,
59+
request,
60+
new ServerSentEventsRestActionListener(channel, threadPool)
61+
);
4862
}
4963
}

0 commit comments

Comments
 (0)