Skip to content

Commit b0f2b4f

Browse files
Adding context preserving fix
1 parent 1004da2 commit b0f2b4f

File tree

7 files changed

+61
-13
lines changed

7 files changed

+61
-13
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import org.apache.http.nio.util.SimpleInputBuffer;
1818
import org.apache.http.protocol.HttpContext;
1919
import org.apache.http.util.EntityUtils;
20+
import org.apache.lucene.util.SetOnce;
2021
import org.elasticsearch.client.Request;
2122
import org.elasticsearch.client.RequestOptions;
2223
import org.elasticsearch.client.Response;
@@ -43,6 +44,7 @@
4344
import org.elasticsearch.rest.RestHandler;
4445
import org.elasticsearch.rest.RestRequest;
4546
import org.elasticsearch.test.ESIntegTestCase;
47+
import org.elasticsearch.threadpool.ThreadPool;
4648
import org.elasticsearch.xcontent.ToXContent;
4749
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4850
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
@@ -52,6 +54,7 @@
5254
import java.io.IOException;
5355
import java.nio.charset.StandardCharsets;
5456
import java.util.Collection;
57+
import java.util.Collections;
5558
import java.util.Deque;
5659
import java.util.Iterator;
5760
import java.util.List;
@@ -96,6 +99,14 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
9699
}
97100

98101
public static class StreamingPlugin extends Plugin implements ActionPlugin {
102+
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
103+
104+
@Override
105+
public Collection<?> createComponents(PluginServices services) {
106+
threadPool.set(services.threadPool());
107+
return Collections.emptyList();
108+
}
109+
99110
@Override
100111
public Collection<RestHandler> getRestHandlers(
101112
Settings settings,
@@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
122133
var publisher = new RandomPublisher(requestCount, withError);
123134
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
124135
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
125-
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
136+
new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse);
126137
}
127138
}, new RestHandler() {
128139
@Override
@@ -132,7 +143,7 @@ public List<Route> routes() {
132143

133144
@Override
134145
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
135-
new ServerSentEventsRestActionListener(channel).onFailure(expectedException);
146+
new ServerSentEventsRestActionListener(channel, threadPool.get()).onFailure(expectedException);
136147
}
137148
}, new RestHandler() {
138149
@Override
@@ -143,7 +154,7 @@ public List<Route> routes() {
143154
@Override
144155
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
145156
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
146-
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
157+
new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse);
147158
}
148159
});
149160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -197,9 +197,11 @@ public List<RestHandler> getRestHandlers(
197197
Supplier<DiscoveryNodes> nodesInCluster,
198198
Predicate<NodeFeature> clusterSupportsFeature
199199
) {
200+
assert serviceComponents.get() != null : "serviceComponents must be set before retrieving the rest handlers";
201+
200202
var availableRestActions = List.of(
201203
new RestInferenceAction(),
202-
new RestStreamInferenceAction(),
204+
new RestStreamInferenceAction(serviceComponents.get().threadPool()),
203205
new RestGetInferenceModelAction(),
204206
new RestPutInferenceModelAction(),
205207
new RestUpdateInferenceModelAction(),
@@ -208,7 +210,7 @@ public List<RestHandler> getRestHandlers(
208210
new RestGetInferenceServicesAction()
209211
);
210212
List<RestHandler> conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled()
211-
? List.of(new RestUnifiedCompletionInferenceAction())
213+
? List.of(new RestUnifiedCompletionInferenceAction(serviceComponents.get().threadPool()))
212214
: List.of();
213215

214216
return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,25 @@
1111
import org.elasticsearch.rest.RestChannel;
1212
import org.elasticsearch.rest.Scope;
1313
import org.elasticsearch.rest.ServerlessScope;
14+
import org.elasticsearch.threadpool.ThreadPool;
1415
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1516

1617
import java.util.List;
18+
import java.util.Objects;
1719

1820
import static org.elasticsearch.rest.RestRequest.Method.POST;
1921
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH;
2022
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH;
2123

2224
@ServerlessScope(Scope.PUBLIC)
2325
public class RestStreamInferenceAction extends BaseInferenceAction {
26+
private final ThreadPool threadPool;
27+
28+
public RestStreamInferenceAction(ThreadPool threadPool) {
29+
super();
30+
this.threadPool = Objects.requireNonNull(threadPool);
31+
}
32+
2433
@Override
2534
public String getName() {
2635
return "stream_inference_action";
@@ -38,6 +47,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques
3847

3948
@Override
4049
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
41-
return new ServerSentEventsRestActionListener(channel);
50+
return new ServerSentEventsRestActionListener(channel, threadPool);
4251
}
4352
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,26 @@
1212
import org.elasticsearch.rest.RestRequest;
1313
import org.elasticsearch.rest.Scope;
1414
import org.elasticsearch.rest.ServerlessScope;
15+
import org.elasticsearch.threadpool.ThreadPool;
1516
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
1617

1718
import java.io.IOException;
1819
import java.util.List;
20+
import java.util.Objects;
1921

2022
import static org.elasticsearch.rest.RestRequest.Method.POST;
2123
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH;
2224
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH;
2325

2426
@ServerlessScope(Scope.PUBLIC)
2527
public class RestUnifiedCompletionInferenceAction extends BaseRestHandler {
28+
private final ThreadPool threadPool;
29+
30+
public RestUnifiedCompletionInferenceAction(ThreadPool threadPool) {
31+
super();
32+
this.threadPool = Objects.requireNonNull(threadPool);
33+
}
34+
2635
@Override
2736
public String getName() {
2837
return "unified_inference_action";
@@ -44,6 +53,10 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4453
request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser);
4554
}
4655

47-
return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel));
56+
return channel -> client.execute(
57+
UnifiedCompletionAction.INSTANCE,
58+
request,
59+
new ServerSentEventsRestActionListener(channel, threadPool)
60+
);
4861
}
4962
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.ElasticsearchException;
1414
import org.elasticsearch.ExceptionsHelper;
1515
import org.elasticsearch.action.ActionListener;
16+
import org.elasticsearch.action.support.ContextPreservingActionListener;
1617
import org.elasticsearch.common.bytes.ReleasableBytesReference;
1718
import org.elasticsearch.common.collect.Iterators;
1819
import org.elasticsearch.common.io.stream.BytesStream;
@@ -29,6 +30,7 @@
2930
import org.elasticsearch.rest.RestResponse;
3031
import org.elasticsearch.rest.RestStatus;
3132
import org.elasticsearch.tasks.TaskCancelledException;
33+
import org.elasticsearch.threadpool.ThreadPool;
3234
import org.elasticsearch.xcontent.ToXContent;
3335
import org.elasticsearch.xcontent.XContentBuilder;
3436
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
@@ -38,6 +40,7 @@
3840
import java.nio.charset.StandardCharsets;
3941
import java.util.Iterator;
4042
import java.util.Map;
43+
import java.util.Objects;
4144
import java.util.concurrent.Flow;
4245
import java.util.concurrent.atomic.AtomicBoolean;
4346

@@ -55,6 +58,7 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
5558
private final AtomicBoolean isLastPart = new AtomicBoolean(false);
5659
private final RestChannel channel;
5760
private final ToXContent.Params params;
61+
private final ThreadPool threadPool;
5862

5963
/**
6064
* A listener for the first part of the next entry to become available for transmission.
@@ -66,13 +70,14 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
6670
*/
6771
private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener;
6872

69-
public ServerSentEventsRestActionListener(RestChannel channel) {
70-
this(channel, channel.request());
73+
public ServerSentEventsRestActionListener(RestChannel channel, ThreadPool threadPool) {
74+
this(channel, channel.request(), threadPool);
7175
}
7276

73-
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) {
77+
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, ThreadPool threadPool) {
7478
this.channel = channel;
7579
this.params = params;
80+
this.threadPool = Objects.requireNonNull(threadPool);
7681
}
7782

7883
@Override
@@ -99,7 +104,7 @@ protected void ensureOpen() {
99104
}
100105

101106
private void initializeStream(InferenceAction.Response response) {
102-
nextBodyPartListener = ActionListener.wrap(bodyPart -> {
107+
ActionListener<ChunkedRestResponseBodyPart> chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> {
103108
// this is the first response, so we need to send the RestResponse to open the stream
104109
// all subsequent bytes will be delivered through the nextBodyPartListener
105110
channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release));
@@ -115,6 +120,12 @@ private void initializeStream(InferenceAction.Response response) {
115120
)
116121
);
117122
});
123+
124+
nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext(
125+
chunkedResponseBodyActionListener,
126+
threadPool.getThreadContext()
127+
);
128+
118129
// subscribe will call onSubscribe, which requests the first chunk
119130
response.publisher().subscribe(subscriber);
120131
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.rest.RestRequest;
1313
import org.elasticsearch.test.rest.FakeRestRequest;
1414
import org.elasticsearch.test.rest.RestActionTestCase;
15+
import org.elasticsearch.threadpool.TestThreadPool;
1516
import org.elasticsearch.xcontent.XContentType;
1617
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1718
import org.junit.Before;
@@ -25,7 +26,7 @@ public class RestStreamInferenceActionTests extends RestActionTestCase {
2526

2627
@Before
2728
public void setUpAction() {
28-
controller().registerHandler(new RestStreamInferenceAction());
29+
controller().registerHandler(new RestStreamInferenceAction(new TestThreadPool(getTestName())));
2930
}
3031

3132
public void testStreamIsTrue() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.rest.RestResponse;
1818
import org.elasticsearch.test.rest.FakeRestRequest;
1919
import org.elasticsearch.test.rest.RestActionTestCase;
20+
import org.elasticsearch.threadpool.TestThreadPool;
2021
import org.elasticsearch.xcontent.XContentType;
2122
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
2223
import org.junit.Before;
@@ -30,7 +31,7 @@ public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCas
3031

3132
@Before
3233
public void setUpAction() {
33-
controller().registerHandler(new RestUnifiedCompletionInferenceAction());
34+
controller().registerHandler(new RestUnifiedCompletionInferenceAction(new TestThreadPool(getTestName())));
3435
}
3536

3637
public void testStreamIsTrue() {

0 commit comments

Comments
 (0)