Skip to content

Commit 69bbd54

Browse files
[ML] Fix loss of context in the inference API for streaming APIs (#118999) (#119222)
* Adding context preserving fix * Update docs/changelog/118999.yaml * Update docs/changelog/118999.yaml * Using a setonce and adding a test * Updating the changelog (cherry picked from commit 7ba3cb9) # Conflicts: # x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java # x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java
1 parent c65e727 commit 69bbd54

File tree

8 files changed

+94
-16
lines changed

8 files changed

+94
-16
lines changed

docs/changelog/118999.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 118999
2+
summary: Fix loss of context in the inference API for streaming APIs
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 119000

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.Map;
3131
import java.util.concurrent.CountDownLatch;
3232
import java.util.concurrent.TimeUnit;
33+
import java.util.function.Consumer;
3334

3435
import static org.hamcrest.Matchers.anyOf;
3536
import static org.hamcrest.Matchers.equalTo;
@@ -336,20 +337,34 @@ protected Map<String, Object> infer(String modelId, List<String> input) throws I
336337
return inferInternal(endpoint, input, Map.of());
337338
}
338339

339-
protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
340+
protected Deque<ServerSentEvent> streamInferOnMockService(
341+
String modelId,
342+
TaskType taskType,
343+
List<String> input,
344+
@Nullable Consumer<Response> responseConsumerCallback
345+
) throws Exception {
340346
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
341-
return callAsync(endpoint, input);
347+
return callAsync(endpoint, input, responseConsumerCallback);
342348
}
343349

344-
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
345-
var responseConsumer = new AsyncInferenceResponseConsumer();
350+
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input, @Nullable Consumer<Response> responseConsumerCallback)
351+
throws Exception {
346352
var request = new Request("POST", endpoint);
347353
request.setJsonEntity(jsonBody(input));
354+
355+
return execAsyncCall(request, responseConsumerCallback);
356+
}
357+
358+
private Deque<ServerSentEvent> execAsyncCall(Request request, @Nullable Consumer<Response> responseConsumerCallback) throws Exception {
359+
var responseConsumer = new AsyncInferenceResponseConsumer();
348360
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
349361
var latch = new CountDownLatch(1);
350362
client().performRequestAsync(request, new ResponseListener() {
351363
@Override
352364
public void onSuccess(Response response) {
365+
if (responseConsumerCallback != null) {
366+
responseConsumerCallback.accept(response);
367+
}
353368
latch.countDown();
354369
}
355370

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.xpack.inference;
1111

1212
import org.apache.http.util.EntityUtils;
13+
import org.elasticsearch.client.Response;
1314
import org.elasticsearch.client.ResponseException;
1415
import org.elasticsearch.common.settings.Settings;
1516
import org.elasticsearch.inference.TaskType;
@@ -19,6 +20,7 @@
1920
import java.util.Map;
2021
import java.util.Objects;
2122
import java.util.Set;
23+
import java.util.function.Consumer;
2224
import java.util.function.Function;
2325
import java.util.stream.IntStream;
2426
import java.util.stream.Stream;
@@ -28,9 +30,15 @@
2830
import static org.hamcrest.Matchers.equalTo;
2931
import static org.hamcrest.Matchers.equalToIgnoringCase;
3032
import static org.hamcrest.Matchers.hasSize;
33+
import static org.hamcrest.Matchers.is;
3134

3235
public class InferenceCrudIT extends InferenceBaseRestTest {
3336

37+
private static final Consumer<Response> VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat(
38+
r.getHeader("X-elastic-product"),
39+
is("Elasticsearch")
40+
);
41+
3442
@SuppressWarnings("unchecked")
3543
public void testCRUD() throws IOException {
3644
for (int i = 0; i < 5; i++) {
@@ -282,7 +290,7 @@ public void testUnsupportedStream() throws Exception {
282290
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
283291

284292
try {
285-
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
293+
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)), null);
286294
assertThat(events.size(), equalTo(2));
287295
events.forEach(event -> {
288296
switch (event.name()) {
@@ -309,7 +317,7 @@ public void testSupportedStream() throws Exception {
309317

310318
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphaOfLength(10)).toList();
311319
try {
312-
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input);
320+
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
313321

314322
var expectedResponses = Stream.concat(
315323
input.stream().map(String::toUpperCase).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.apache.http.nio.util.SimpleInputBuffer;
1818
import org.apache.http.protocol.HttpContext;
1919
import org.apache.http.util.EntityUtils;
20+
import org.apache.lucene.util.SetOnce;
2021
import org.elasticsearch.client.Request;
2122
import org.elasticsearch.client.RequestOptions;
2223
import org.elasticsearch.client.Response;
@@ -43,6 +44,7 @@
4344
import org.elasticsearch.rest.RestHandler;
4445
import org.elasticsearch.rest.RestRequest;
4546
import org.elasticsearch.test.ESIntegTestCase;
47+
import org.elasticsearch.threadpool.ThreadPool;
4648
import org.elasticsearch.xcontent.ToXContent;
4749
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4850
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
@@ -52,6 +54,7 @@
5254
import java.io.IOException;
5355
import java.nio.charset.StandardCharsets;
5456
import java.util.Collection;
57+
import java.util.Collections;
5558
import java.util.Deque;
5659
import java.util.Iterator;
5760
import java.util.List;
@@ -96,6 +99,14 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
9699
}
97100

98101
public static class StreamingPlugin extends Plugin implements ActionPlugin {
102+
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
103+
104+
@Override
105+
public Collection<?> createComponents(PluginServices services) {
106+
threadPool.set(services.threadPool());
107+
return Collections.emptyList();
108+
}
109+
99110
@Override
100111
public Collection<RestHandler> getRestHandlers(
101112
Settings settings,
@@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
122133
var publisher = new RandomPublisher(requestCount, withError);
123134
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
124135
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
125-
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
136+
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
126137
}
127138
}, new RestHandler() {
128139
@Override
@@ -132,7 +143,7 @@ public List<Route> routes() {
132143

133144
@Override
134145
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
135-
new ServerSentEventsRestActionListener(channel).onFailure(expectedException);
146+
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
136147
}
137148
}, new RestHandler() {
138149
@Override
@@ -143,7 +154,7 @@ public List<Route> routes() {
143154
@Override
144155
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
145156
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
146-
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
157+
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
147158
}
148159
});
149160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.search.rank.RankDoc;
4040
import org.elasticsearch.threadpool.ExecutorBuilder;
4141
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
42+
import org.elasticsearch.threadpool.ThreadPool;
4243
import org.elasticsearch.xcontent.ParseField;
4344
import org.elasticsearch.xpack.core.ClientHelper;
4445
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
@@ -140,6 +141,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
140141
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
141142
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
142143
private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
144+
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
145+
// not being initialized yet
146+
private final SetOnce<ThreadPool> threadPoolSetOnce = new SetOnce<>();
143147
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
144148
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
145149
private List<InferenceServiceExtension> inferenceServiceExtensions;
@@ -176,7 +180,7 @@ public List<RestHandler> getRestHandlers(
176180
) {
177181
return List.of(
178182
new RestInferenceAction(),
179-
new RestStreamInferenceAction(),
183+
new RestStreamInferenceAction(threadPoolSetOnce),
180184
new RestGetInferenceModelAction(),
181185
new RestPutInferenceModelAction(),
182186
new RestUpdateInferenceModelAction(),
@@ -190,6 +194,7 @@ public Collection<?> createComponents(PluginServices services) {
190194
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
191195
var truncator = new Truncator(settings, services.clusterService());
192196
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
197+
threadPoolSetOnce.set(services.threadPool());
193198

194199
var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager);
195200
var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,30 @@
77

88
package org.elasticsearch.xpack.inference.rest;
99

10+
import org.apache.lucene.util.SetOnce;
1011
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.rest.RestChannel;
1213
import org.elasticsearch.rest.Scope;
1314
import org.elasticsearch.rest.ServerlessScope;
15+
import org.elasticsearch.threadpool.ThreadPool;
1416
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1517

1618
import java.util.List;
19+
import java.util.Objects;
1720

1821
import static org.elasticsearch.rest.RestRequest.Method.POST;
1922
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH;
2023
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH;
2124

2225
@ServerlessScope(Scope.PUBLIC)
2326
public class RestStreamInferenceAction extends BaseInferenceAction {
27+
private final SetOnce<ThreadPool> threadPool;
28+
29+
public RestStreamInferenceAction(SetOnce<ThreadPool> threadPool) {
30+
super();
31+
this.threadPool = Objects.requireNonNull(threadPool);
32+
}
33+
2434
@Override
2535
public String getName() {
2636
return "stream_inference_action";
@@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques
3848

3949
@Override
4050
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
41-
return new ServerSentEventsRestActionListener(channel);
51+
return new ServerSentEventsRestActionListener(channel, threadPool);
4252
}
4353
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.BytesRef;
13+
import org.apache.lucene.util.SetOnce;
1314
import org.elasticsearch.ElasticsearchException;
1415
import org.elasticsearch.ExceptionsHelper;
1516
import org.elasticsearch.action.ActionListener;
17+
import org.elasticsearch.action.support.ContextPreservingActionListener;
1618
import org.elasticsearch.common.bytes.ReleasableBytesReference;
1719
import org.elasticsearch.common.collect.Iterators;
1820
import org.elasticsearch.common.io.stream.BytesStream;
@@ -30,6 +32,7 @@
3032
import org.elasticsearch.rest.RestResponse;
3133
import org.elasticsearch.rest.RestStatus;
3234
import org.elasticsearch.tasks.TaskCancelledException;
35+
import org.elasticsearch.threadpool.ThreadPool;
3336
import org.elasticsearch.xcontent.ToXContent;
3437
import org.elasticsearch.xcontent.XContentBuilder;
3538
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
@@ -39,6 +42,7 @@
3942
import java.nio.charset.StandardCharsets;
4043
import java.util.Iterator;
4144
import java.util.Map;
45+
import java.util.Objects;
4246
import java.util.concurrent.Flow;
4347
import java.util.concurrent.atomic.AtomicBoolean;
4448

@@ -55,6 +59,7 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
5559
private final AtomicBoolean isLastPart = new AtomicBoolean(false);
5660
private final RestChannel channel;
5761
private final ToXContent.Params params;
62+
private final SetOnce<ThreadPool> threadPool;
5863

5964
/**
6065
* A listener for the first part of the next entry to become available for transmission.
@@ -66,13 +71,14 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
6671
*/
6772
private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener;
6873

69-
public ServerSentEventsRestActionListener(RestChannel channel) {
70-
this(channel, channel.request());
74+
public ServerSentEventsRestActionListener(RestChannel channel, SetOnce<ThreadPool> threadPool) {
75+
this(channel, channel.request(), threadPool);
7176
}
7277

73-
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) {
78+
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce<ThreadPool> threadPool) {
7479
this.channel = channel;
7580
this.params = params;
81+
this.threadPool = Objects.requireNonNull(threadPool);
7682
}
7783

7884
@Override
@@ -99,7 +105,7 @@ protected void ensureOpen() {
99105
}
100106

101107
private void initializeStream(InferenceAction.Response response) {
102-
nextBodyPartListener = ActionListener.wrap(bodyPart -> {
108+
ActionListener<ChunkedRestResponseBodyPart> chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> {
103109
// this is the first response, so we need to send the RestResponse to open the stream
104110
// all subsequent bytes will be delivered through the nextBodyPartListener
105111
channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release));
@@ -115,6 +121,12 @@ private void initializeStream(InferenceAction.Response response) {
115121
)
116122
);
117123
});
124+
125+
nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext(
126+
chunkedResponseBodyActionListener,
127+
threadPool.get().getThreadContext()
128+
);
129+
118130
// subscribe will call onSubscribe, which requests the first chunk
119131
response.publisher().subscribe(subscriber);
120132
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
import org.elasticsearch.rest.RestRequest;
1313
import org.elasticsearch.test.rest.FakeRestRequest;
1414
import org.elasticsearch.test.rest.RestActionTestCase;
15+
import org.elasticsearch.threadpool.TestThreadPool;
16+
import org.elasticsearch.threadpool.ThreadPool;
1517
import org.elasticsearch.xcontent.XContentType;
1618
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
19+
import org.junit.After;
1720
import org.junit.Before;
1821

1922
import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
@@ -22,10 +25,18 @@
2225
import static org.hamcrest.Matchers.instanceOf;
2326

2427
public class RestStreamInferenceActionTests extends RestActionTestCase {
28+
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
2529

2630
@Before
2731
public void setUpAction() {
28-
controller().registerHandler(new RestStreamInferenceAction());
32+
threadPool.set(new TestThreadPool(getTestName()));
33+
controller().registerHandler(new RestStreamInferenceAction(threadPool));
34+
}
35+
36+
@After
37+
public void tearDownAction() {
38+
terminate(threadPool.get());
39+
2940
}
3041

3142
public void testStreamIsTrue() {

0 commit comments

Comments
 (0)