Skip to content

Commit cb42fd4

Browse files
[ML] Stream Inference API (elastic#113158) (elastic#113423)
Create the `POST _inference/<task>/<id>/_stream` and `POST _inference/<id>/_stream` APIs. The REST streaming API reuses InferenceAction. For now, all services and task types return an HTTP 405 status code and an error message. Co-authored-by: Elastic Machine <[email protected]>
1 parent 9a21ca6 commit cb42fd4

File tree

24 files changed

+798
-121
lines changed

24 files changed

+798
-121
lines changed

docs/changelog/113158.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 113158
2+
summary: Adds a new Inference API for streaming responses back to the user.
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,21 @@ default boolean isInClusterService() {
188188
* @return {@link TransportVersion} specifying the version
189189
*/
190190
TransportVersion getMinimalSupportedVersion();
191+
192+
/**
193+
* Returns the set of task types for which this service provider supports the streaming API.
194+
* @return set of supported task types. This default implementation returns an empty set.
195+
*/
196+
default Set<TaskType> supportedStreamingTasks() {
197+
return Set.of();
198+
}
199+
200+
/**
201+
* Checks the task type against the set of supported streaming tasks returned by {@link #supportedStreamingTasks()}.
202+
* @param taskType the task type to look up in the set of supported streaming tasks
203+
* @return true if the set returned by {@link #supportedStreamingTasks()} contains the taskType
204+
*/
205+
default boolean canStream(TaskType taskType) {
206+
return supportedStreamingTasks().contains(taskType);
207+
}
191208
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
9292
private final Map<String, Object> taskSettings;
9393
private final InputType inputType;
9494
private final TimeValue inferenceTimeout;
95+
private final boolean stream;
9596

9697
public Request(
9798
TaskType taskType,
@@ -100,7 +101,8 @@ public Request(
100101
List<String> input,
101102
Map<String, Object> taskSettings,
102103
InputType inputType,
103-
TimeValue inferenceTimeout
104+
TimeValue inferenceTimeout,
105+
boolean stream
104106
) {
105107
this.taskType = taskType;
106108
this.inferenceEntityId = inferenceEntityId;
@@ -109,6 +111,7 @@ public Request(
109111
this.taskSettings = taskSettings;
110112
this.inputType = inputType;
111113
this.inferenceTimeout = inferenceTimeout;
114+
this.stream = stream;
112115
}
113116

114117
public Request(StreamInput in) throws IOException {
@@ -134,6 +137,9 @@ public Request(StreamInput in) throws IOException {
134137
this.query = null;
135138
this.inferenceTimeout = DEFAULT_TIMEOUT;
136139
}
140+
141+
// streaming is not supported yet for transport traffic
142+
this.stream = false;
137143
}
138144

139145
public TaskType getTaskType() {
@@ -165,7 +171,7 @@ public TimeValue getInferenceTimeout() {
165171
}
166172

167173
public boolean isStreaming() {
168-
return false;
174+
return stream;
169175
}
170176

171177
@Override
@@ -261,6 +267,7 @@ public static class Builder {
261267
private Map<String, Object> taskSettings = Map.of();
262268
private String query;
263269
private TimeValue timeout = DEFAULT_TIMEOUT;
270+
private boolean stream = false;
264271

265272
private Builder() {}
266273

@@ -303,8 +310,13 @@ private Builder setInferenceTimeout(String inferenceTimeout) {
303310
return setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName()));
304311
}
305312

313+
/**
 * Sets the streaming flag that {@code build()} passes to the {@code Request} constructor.
 * The builder's flag is {@code false} unless this method is called.
 * @param stream the value the built request should carry as its streaming flag
 * @return this builder, to allow call chaining
 */
public Builder setStream(boolean stream) {
314+
this.stream = stream;
315+
return this;
316+
}
317+
306318
public Request build() {
307-
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout);
319+
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
308320
}
309321
}
310322

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ protected InferenceAction.Request createTestInstance() {
4646
randomList(1, 5, () -> randomAlphaOfLength(8)),
4747
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
4848
randomFrom(InputType.values()),
49-
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
49+
TimeValue.timeValueMillis(randomLongBetween(1, 2048)),
50+
false
5051
);
5152
}
5253

@@ -80,7 +81,8 @@ public void testValidation_TextEmbedding() {
8081
List.of("input"),
8182
null,
8283
null,
83-
null
84+
null,
85+
false
8486
);
8587
ActionRequestValidationException e = request.validate();
8688
assertNull(e);
@@ -94,7 +96,8 @@ public void testValidation_Rerank() {
9496
List.of("input"),
9597
null,
9698
null,
97-
null
99+
null,
100+
false
98101
);
99102
ActionRequestValidationException e = request.validate();
100103
assertNull(e);
@@ -108,7 +111,8 @@ public void testValidation_TextEmbedding_Null() {
108111
null,
109112
null,
110113
null,
111-
null
114+
null,
115+
false
112116
);
113117
ActionRequestValidationException inputNullError = inputNullRequest.validate();
114118
assertNotNull(inputNullError);
@@ -123,7 +127,8 @@ public void testValidation_TextEmbedding_Empty() {
123127
List.of(),
124128
null,
125129
null,
126-
null
130+
null,
131+
false
127132
);
128133
ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate();
129134
assertNotNull(inputEmptyError);
@@ -138,7 +143,8 @@ public void testValidation_Rerank_Null() {
138143
List.of("input"),
139144
null,
140145
null,
141-
null
146+
null,
147+
false
142148
);
143149
ActionRequestValidationException queryNullError = queryNullRequest.validate();
144150
assertNotNull(queryNullError);
@@ -153,7 +159,8 @@ public void testValidation_Rerank_Empty() {
153159
List.of("input"),
154160
null,
155161
null,
156-
null
162+
null,
163+
false
157164
);
158165
ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate();
159166
assertNotNull(queryEmptyError);
@@ -185,7 +192,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
185192
instance.getInput(),
186193
instance.getTaskSettings(),
187194
instance.getInputType(),
188-
instance.getInferenceTimeout()
195+
instance.getInferenceTimeout(),
196+
false
189197
);
190198
}
191199
case 1 -> new InferenceAction.Request(
@@ -195,7 +203,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
195203
instance.getInput(),
196204
instance.getTaskSettings(),
197205
instance.getInputType(),
198-
instance.getInferenceTimeout()
206+
instance.getInferenceTimeout(),
207+
false
199208
);
200209
case 2 -> {
201210
var changedInputs = new ArrayList<String>(instance.getInput());
@@ -207,7 +216,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
207216
changedInputs,
208217
instance.getTaskSettings(),
209218
instance.getInputType(),
210-
instance.getInferenceTimeout()
219+
instance.getInferenceTimeout(),
220+
false
211221
);
212222
}
213223
case 3 -> {
@@ -225,7 +235,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
225235
instance.getInput(),
226236
taskSettings,
227237
instance.getInputType(),
228-
instance.getInferenceTimeout()
238+
instance.getInferenceTimeout(),
239+
false
229240
);
230241
}
231242
case 4 -> {
@@ -237,7 +248,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
237248
instance.getInput(),
238249
instance.getTaskSettings(),
239250
nextInputType,
240-
instance.getInferenceTimeout()
251+
instance.getInferenceTimeout(),
252+
false
241253
);
242254
}
243255
case 5 -> new InferenceAction.Request(
@@ -247,7 +259,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
247259
instance.getInput(),
248260
instance.getTaskSettings(),
249261
instance.getInputType(),
250-
instance.getInferenceTimeout()
262+
instance.getInferenceTimeout(),
263+
false
251264
);
252265
case 6 -> {
253266
var newDuration = Duration.of(
@@ -262,7 +275,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
262275
instance.getInput(),
263276
instance.getTaskSettings(),
264277
instance.getInputType(),
265-
TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis())
278+
TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()),
279+
false
266280
);
267281
}
268282
default -> throw new UnsupportedOperationException();
@@ -279,7 +293,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
279293
instance.getInput().subList(0, 1),
280294
instance.getTaskSettings(),
281295
InputType.UNSPECIFIED,
282-
InferenceAction.Request.DEFAULT_TIMEOUT
296+
InferenceAction.Request.DEFAULT_TIMEOUT,
297+
false
283298
);
284299
} else if (version.before(TransportVersions.V_8_13_0)) {
285300
return new InferenceAction.Request(
@@ -289,7 +304,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
289304
instance.getInput(),
290305
instance.getTaskSettings(),
291306
InputType.UNSPECIFIED,
292-
InferenceAction.Request.DEFAULT_TIMEOUT
307+
InferenceAction.Request.DEFAULT_TIMEOUT,
308+
false
293309
);
294310
} else if (version.before(TransportVersions.V_8_13_0)
295311
&& (instance.getInputType() == InputType.UNSPECIFIED
@@ -302,7 +318,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
302318
instance.getInput(),
303319
instance.getTaskSettings(),
304320
InputType.INGEST,
305-
InferenceAction.Request.DEFAULT_TIMEOUT
321+
InferenceAction.Request.DEFAULT_TIMEOUT,
322+
false
306323
);
307324
} else if (version.before(TransportVersions.V_8_13_0)
308325
&& (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) {
@@ -313,7 +330,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
313330
instance.getInput(),
314331
instance.getTaskSettings(),
315332
InputType.UNSPECIFIED,
316-
InferenceAction.Request.DEFAULT_TIMEOUT
333+
InferenceAction.Request.DEFAULT_TIMEOUT,
334+
false
317335
);
318336
} else if (version.before(TransportVersions.V_8_14_0)) {
319337
return new InferenceAction.Request(
@@ -323,7 +341,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
323341
instance.getInput(),
324342
instance.getTaskSettings(),
325343
instance.getInputType(),
326-
InferenceAction.Request.DEFAULT_TIMEOUT
344+
InferenceAction.Request.DEFAULT_TIMEOUT,
345+
false
327346
);
328347
}
329348

@@ -339,7 +358,8 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio
339358
List.of(),
340359
Map.of(),
341360
InputType.UNSPECIFIED,
342-
InferenceAction.Request.DEFAULT_TIMEOUT
361+
InferenceAction.Request.DEFAULT_TIMEOUT,
362+
false
343363
),
344364
TransportVersions.V_8_13_0
345365
);
@@ -353,7 +373,8 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn
353373
List.of(),
354374
Map.of(),
355375
InputType.INGEST,
356-
InferenceAction.Request.DEFAULT_TIMEOUT
376+
InferenceAction.Request.DEFAULT_TIMEOUT,
377+
false
357378
);
358379

359380
InferenceAction.Request deserializedInstance = copyWriteable(
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.apache.http.HttpEntity;
11+
import org.apache.http.HttpResponse;
12+
import org.apache.http.entity.ContentType;
13+
import org.apache.http.nio.ContentDecoder;
14+
import org.apache.http.nio.IOControl;
15+
import org.apache.http.nio.protocol.AbstractAsyncResponseConsumer;
16+
import org.apache.http.nio.util.SimpleInputBuffer;
17+
import org.apache.http.protocol.HttpContext;
18+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
19+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayDeque;
23+
import java.util.Deque;
24+
import java.util.concurrent.atomic.AtomicReference;
25+
26+
class AsyncInferenceResponseConsumer extends AbstractAsyncResponseConsumer<HttpResponse> {
27+
private final AtomicReference<HttpResponse> httpResponse = new AtomicReference<>();
28+
private final Deque<ServerSentEvent> collector = new ArrayDeque<>();
29+
private final ServerSentEventParser sseParser = new ServerSentEventParser();
30+
private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096);
31+
32+
@Override
33+
protected void onResponseReceived(HttpResponse httpResponse) {
34+
this.httpResponse.set(httpResponse);
35+
}
36+
37+
@Override
38+
protected void onContentReceived(ContentDecoder contentDecoder, IOControl ioControl) throws IOException {
39+
inputBuffer.consumeContent(contentDecoder);
40+
}
41+
42+
@Override
43+
protected void onEntityEnclosed(HttpEntity httpEntity, ContentType contentType) {
44+
httpResponse.updateAndGet(response -> {
45+
response.setEntity(httpEntity);
46+
return response;
47+
});
48+
}
49+
50+
@Override
51+
protected HttpResponse buildResult(HttpContext httpContext) {
52+
var allBytes = new byte[inputBuffer.length()];
53+
try {
54+
inputBuffer.read(allBytes);
55+
sseParser.parse(allBytes).forEach(collector::offer);
56+
} catch (IOException e) {
57+
failed(e);
58+
}
59+
return httpResponse.get();
60+
}
61+
62+
@Override
63+
protected void releaseResources() {}
64+
65+
Deque<ServerSentEvent> events() {
66+
return collector;
67+
}
68+
}

0 commit comments

Comments
 (0)