Skip to content

Commit 1e2c19f

Browse files
[ML] Add stream flag to inference providers (elastic#113424)
Pass the stream flag from the REST request through to the inference providers via the InferenceInputs.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
1 parent d8a3215 commit 1e2c19f

File tree

29 files changed, with 91 additions and 24 deletions.

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ void parseRequestConfig(
8585
* @param model The model
8686
* @param query Inference query, mainly for re-ranking
8787
* @param input Inference input
88+
* @param stream Stream inference results
8889
* @param taskSettings Settings in the request to override the model's defaults
8990
* @param inputType For search, ingest etc
9091
* @param timeout The timeout for the request
@@ -94,6 +95,7 @@ void infer(
9495
Model model,
9596
@Nullable String query,
9697
List<String> input,
98+
boolean stream,
9799
Map<String, Object> taskSettings,
98100
InputType inputType,
99101
TimeValue timeout,

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ public void infer(
9494
Model model,
9595
@Nullable String query,
9696
List<String> input,
97+
boolean stream,
9798
Map<String, Object> taskSettings,
9899
InputType inputType,
99100
TimeValue timeout,

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public void infer(
8585
Model model,
8686
@Nullable String query,
8787
List<String> input,
88+
boolean stream,
8889
Map<String, Object> taskSettings,
8990
InputType inputType,
9091
TimeValue timeout,

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ public void infer(
8888
Model model,
8989
@Nullable String query,
9090
List<String> input,
91+
boolean stream,
9192
Map<String, Object> taskSettings,
9293
InputType inputType,
9394
TimeValue timeout,

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public void infer(
8585
Model model,
8686
String query,
8787
List<String> input,
88+
boolean stream,
8889
Map<String, Object> taskSettings,
8990
InputType inputType,
9091
TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ private void inferOnService(
114114
model,
115115
request.getQuery(),
116116
request.getInput(),
117+
request.isStreaming(),
117118
request.getTaskSettings(),
118119
request.getInputType(),
119120
request.getInferenceTimeout(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private static String getStatusCodeErrorMessage(Request request, HttpResult resu
4646
}
4747

4848
public static void checkForEmptyBody(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
49-
if (result.isBodyEmpty()) {
49+
if (result.isBodyEmpty() && (request.isStreaming() == false)) {
5050
String message = format("Response body was empty for request from inference entity id [%s]", request.getInferenceEntityId());
5151
throttlerManager.warn(logger, message);
5252
throw new IllegalStateException(message);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,23 @@ public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) {
2121
}
2222

2323
private final List<String> input;
24+
private final boolean stream;
2425

25-
public DocumentsOnlyInput(List<String> chunks) {
26+
public DocumentsOnlyInput(List<String> input) {
27+
this(input, false);
28+
}
29+
30+
public DocumentsOnlyInput(List<String> input, boolean stream) {
2631
super();
27-
this.input = Objects.requireNonNull(chunks);
32+
this.input = Objects.requireNonNull(input);
33+
this.stream = stream;
2834
}
2935

3036
public List<String> getInputs() {
3137
return this.input;
3238
}
39+
40+
public boolean stream() {
41+
return stream;
42+
}
3343
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
2121
}
2222

2323
private final String query;
24+
private final List<String> chunks;
25+
private final boolean stream;
26+
27+
public QueryAndDocsInputs(String query, List<String> chunks) {
28+
this(query, chunks, false);
29+
}
30+
31+
public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
32+
super();
33+
this.query = Objects.requireNonNull(query);
34+
this.chunks = Objects.requireNonNull(chunks);
35+
this.stream = stream;
36+
}
2437

2538
public String getQuery() {
2639
return query;
@@ -30,12 +43,8 @@ public List<String> getChunks() {
3043
return chunks;
3144
}
3245

33-
List<String> chunks;
34-
35-
public QueryAndDocsInputs(String query, List<String> chunks) {
36-
super();
37-
this.query = Objects.requireNonNull(query);
38-
this.chunks = Objects.requireNonNull(chunks);
46+
public boolean stream() {
47+
return stream;
3948
}
4049

4150
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,17 @@ public void infer(
5151
Model model,
5252
@Nullable String query,
5353
List<String> input,
54+
boolean stream,
5455
Map<String, Object> taskSettings,
5556
InputType inputType,
5657
TimeValue timeout,
5758
ActionListener<InferenceServiceResults> listener
5859
) {
5960
init();
6061
if (query != null) {
61-
doInfer(model, new QueryAndDocsInputs(query, input), taskSettings, inputType, timeout, listener);
62+
doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener);
6263
} else {
63-
doInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener);
64+
doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener);
6465
}
6566
}
6667

0 commit comments

Comments
 (0)