
Commit 3bac931

Add chunking to perform inference API
1 parent 755c392 commit 3bac931

5 files changed: +269 -4 lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
@@ -181,6 +181,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ = def(8_776_00_0);
     public static final TransportVersion SIMULATE_MAPPING_ADDITION = def(8_777_00_0);
     public static final TransportVersion INTRODUCE_ALL_APPLICABLE_SELECTOR = def(8_778_00_0);
+    public static final TransportVersion CHUNKING_ENABLED_PERFORM_INFERENCE = def(8_779_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 58 additions & 3 deletions
@@ -60,13 +60,15 @@ public static class Request extends ActionRequest {
         public static final ParseField INPUT = new ParseField("input");
         public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
         public static final ParseField QUERY = new ParseField("query");
+        public static final ParseField CHUNKING_ENABLED = new ParseField("chunking_enabled");
         public static final ParseField TIMEOUT = new ParseField("timeout");
 
         static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
         static {
             PARSER.declareStringArray(Request.Builder::setInput, INPUT);
             PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
             PARSER.declareString(Request.Builder::setQuery, QUERY);
+            PARSER.declareBoolean(Request.Builder::setChunkingEnabled, CHUNKING_ENABLED);
             PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
         }
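With the ParseField and parser declaration above, a request body can now carry a "chunking_enabled" boolean alongside "input". Below is a minimal sketch, not part of this commit, of building an equivalent request programmatically; only setInput, setChunkingEnabled and build() come from this diff, while the input text and any task-type/endpoint wiring (normally handled by the REST layer) are illustrative assumptions.

    // Hypothetical usage sketch; values and surrounding wiring are illustrative.
    InferenceAction.Request request = new InferenceAction.Request.Builder()
        .setInput(List.of("a long passage of text to chunk and embed"))
        .setChunkingEnabled(true)
        .build();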

@@ -93,6 +95,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
         private final InputType inputType;
         private final TimeValue inferenceTimeout;
         private final boolean stream;
+        private final boolean chunkingEnabled;
 
         public Request(
             TaskType taskType,
@@ -112,6 +115,29 @@ public Request(
             this.inputType = inputType;
             this.inferenceTimeout = inferenceTimeout;
             this.stream = stream;
+            this.chunkingEnabled = false;
+        }
+
+        public Request(
+            TaskType taskType,
+            String inferenceEntityId,
+            String query,
+            List<String> input,
+            Map<String, Object> taskSettings,
+            InputType inputType,
+            TimeValue inferenceTimeout,
+            boolean stream,
+            boolean chunkingEnabled
+        ) {
+            this.taskType = taskType;
+            this.inferenceEntityId = inferenceEntityId;
+            this.query = query;
+            this.input = input;
+            this.taskSettings = taskSettings;
+            this.inputType = inputType;
+            this.inferenceTimeout = inferenceTimeout;
+            this.stream = stream;
+            this.chunkingEnabled = chunkingEnabled;
         }
 
         public Request(StreamInput in) throws IOException {
@@ -138,6 +164,12 @@ public Request(StreamInput in) throws IOException {
                 this.inferenceTimeout = DEFAULT_TIMEOUT;
             }
 
+            if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNKING_ENABLED_PERFORM_INFERENCE)) {
+                this.chunkingEnabled = in.readBoolean();
+            } else {
+                this.chunkingEnabled = false;
+            }
+
             // streaming is not supported yet for transport traffic
             this.stream = false;
         }
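This read pairs with the version-gated write added to writeTo() below: the flag only travels on the wire when both nodes are at CHUNKING_ENABLED_PERFORM_INFERENCE or later, and otherwise falls back to false. A rough backwards-compatibility sketch, not from this commit; requestWithChunkingEnabled is a hypothetical request built with the flag set, and BytesStreamOutput plus setTransportVersion on the stream wrappers are assumed test scaffolding.

    // Sketch only: serialize at an older transport version and confirm the flag falls back to false.
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        out.setTransportVersion(TransportVersions.INTRODUCE_ALL_APPLICABLE_SELECTOR); // pre-chunking version
        requestWithChunkingEnabled.writeTo(out);

        StreamInput in = out.bytes().streamInput();
        in.setTransportVersion(TransportVersions.INTRODUCE_ALL_APPLICABLE_SELECTOR);
        assert new InferenceAction.Request(in).isChunkingEnabled() == false;
    }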
@@ -174,6 +206,10 @@ public boolean isStreaming() {
             return stream;
         }
 
+        public boolean isChunkingEnabled() {
+            return chunkingEnabled;
+        }
+
         @Override
         public ActionRequestValidationException validate() {
             if (input == null) {
@@ -201,6 +237,12 @@ public ActionRequestValidationException validate() {
                 }
             }
 
+            if (chunkingEnabled && ((taskType.equals(TaskType.SPARSE_EMBEDDING) || taskType.equals(TaskType.TEXT_EMBEDDING)) == false)) {
+                var e = new ActionRequestValidationException();
+                e.addValidationError(format("Chunking is only supported for embedding task types."));
+                return e;
+            }
+
             return null;
         }
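The check above rejects chunking_enabled for anything other than the two embedding task types. A short test sketch, not in this commit, exercising that branch; it assumes the nine-argument constructor added above and DEFAULT_TIMEOUT are accessible from the test.

    // Hypothetical validation test: COMPLETION is not an embedding task type.
    var invalid = new InferenceAction.Request(
        TaskType.COMPLETION, "my-endpoint", null, List.of("some input"),
        Map.of(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, false, true
    );
    ActionRequestValidationException e = invalid.validate();
    assert e != null && e.validationErrors().contains("Chunking is only supported for embedding task types.");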

@@ -224,6 +266,10 @@ public void writeTo(StreamOutput out) throws IOException {
                 out.writeOptionalString(query);
                 out.writeTimeValue(inferenceTimeout);
             }
+
+            if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNKING_ENABLED_PERFORM_INFERENCE)) {
+                out.writeBoolean(chunkingEnabled);
+            }
         }
 
         // default for easier testing
@@ -250,12 +296,13 @@ public boolean equals(Object o) {
                 && Objects.equals(taskSettings, request.taskSettings)
                 && Objects.equals(inputType, request.inputType)
                 && Objects.equals(query, request.query)
-                && Objects.equals(inferenceTimeout, request.inferenceTimeout);
+                && Objects.equals(inferenceTimeout, request.inferenceTimeout)
+                && Objects.equals(chunkingEnabled, request.chunkingEnabled);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
+            return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, chunkingEnabled, inferenceTimeout);
         }
 
         public static class Builder {
@@ -266,6 +313,7 @@ public static class Builder {
             private InputType inputType = InputType.UNSPECIFIED;
             private Map<String, Object> taskSettings = Map.of();
             private String query;
+            private boolean chunkingEnabled = false;
             private TimeValue timeout = DEFAULT_TIMEOUT;
             private boolean stream = false;

@@ -291,6 +339,11 @@ public Builder setQuery(String query) {
                 return this;
             }
 
+            public Builder setChunkingEnabled(boolean chunkingEnabled) {
+                this.chunkingEnabled = chunkingEnabled;
+                return this;
+            }
+
             public Builder setInputType(InputType inputType) {
                 this.inputType = inputType;
                 return this;
@@ -316,7 +369,7 @@ public Builder setStream(boolean stream) {
             }
 
             public Request build() {
-                return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
+                return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, chunkingEnabled);
             }
         }

@@ -335,6 +388,8 @@ public String toString() {
                 + this.getInputType()
                 + ", timeout="
                 + this.getInferenceTimeout()
+                + ", chunking_enabled="
+                + this.isChunkingEnabled()
                 + ")";
         }
     }

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/BatchedChunkedInferenceServiceResults.java

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.ToXContent;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+public class BatchedChunkedInferenceServiceResults implements InferenceServiceResults {
+    public static final String NAME = "batched_chunked_inference_service_results";
+    private final TaskType taskType;
+    private final List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults;
+
+    public BatchedChunkedInferenceServiceResults(TaskType taskType, List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults) {
+        this.taskType = taskType;
+        this.chunkedInferenceServiceResults = chunkedInferenceServiceResults;
+    }
+
+    public BatchedChunkedInferenceServiceResults(StreamInput in) throws IOException {
+        // TODO: Figure out how to do this given that you don't know the type of the chunkedInferenceServiceResults
+        this.taskType = in.readEnum(TaskType.class); // TODO
+
+        switch (taskType) {
+            case TEXT_EMBEDDING:
+                this.chunkedInferenceServiceResults = in.readCollectionAsList(InferenceChunkedTextEmbeddingByteResults::new);
+                break;
+            case SPARSE_EMBEDDING:
+                this.chunkedInferenceServiceResults = in.readCollectionAsList(InferenceChunkedSparseEmbeddingResults::new);
+                break;
+            default:
+                throw new IllegalArgumentException("Unknown task type: " + taskType);
+        }
+
+        // TODO: What about BYTE chunked results? Seems like we don't use them anymore? Ask about this.
+    }
+
+    public TaskType getTaskType() {
+        return taskType;
+    }
+
+    public List<ChunkedInferenceServiceResults> getChunkedInferenceServiceResults() {
+        return chunkedInferenceServiceResults;
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToCoordinationFormat() {
+        return List.of();
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToLegacyFormat() {
+        return List.of();
+    }
+
+    @Override
+    public Map<String, Object> asMap() {
+        return Map.of();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeEnum(taskType);
+        out.writeCollection(chunkedInferenceServiceResults, StreamOutput::writeWriteable); // TODO: Is this correct?
+    }
+
+    @Override
+    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+        return chunkedInferenceServiceResults.stream()
+            .map(result -> result.toXContentChunked(params))
+            .reduce(Iterators::concat)
+            .orElseThrow(() -> new RuntimeException("TODO"));
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        BatchedChunkedInferenceServiceResults that = (BatchedChunkedInferenceServiceResults) o;
+        return Objects.equals(taskType, that.getTaskType())
+            && Objects.equals(chunkedInferenceServiceResults, that.getChunkedInferenceServiceResults());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(taskType, chunkedInferenceServiceResults);
+    }
+}
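A rough wire round-trip sketch for the new results holder, not part of this commit; it uses an empty result list so that no concrete chunked-result constructors are needed, and it assumes BytesStreamOutput from the common test scaffolding. Note that with an empty list, toXContentChunked() would hit the orElseThrow branch, which the RuntimeException("TODO") placeholder above appears to flag.

    // Hypothetical round-trip through the transport serialization added above.
    var original = new BatchedChunkedInferenceServiceResults(TaskType.SPARSE_EMBEDDING, List.of());
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        original.writeTo(out);
        var read = new BatchedChunkedInferenceServiceResults(out.bytes().streamInput());
        assert original.equals(read) && read.getTaskType() == TaskType.SPARSE_EMBEDDING;
    }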

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 42 additions & 1 deletion
@@ -13,6 +13,8 @@
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -24,10 +26,12 @@
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.BatchedChunkedInferenceServiceResults;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
+import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;

@@ -83,12 +87,40 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
-            inferOnService(model, request, service.get(), delegate);
+            if (request.isChunkingEnabled()) {
+                chunkedInferOnService(model, request, service.get(), delegate);
+            } else {
+                inferOnService(model, request, service.get(), delegate);
+            }
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
+    private void chunkedInferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceAction.Response> listener
+    ) {
+        // TODO: Check if the if statement is necessary
+        if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
+            inferenceStats.incrementRequestCount(model);
+            service.chunkedInfer(
+                model,
+                request.getQuery(),
+                request.getInput(),
+                request.getTaskSettings(),
+                request.getInputType(),
+                new ChunkingOptions(null, null),
+                request.getInferenceTimeout(),
+                createChunkedListener(listener, request.getTaskType())
+            );
+        } else {
+            listener.onFailure(unsupportedStreamingTaskException(request, service));
+        }
+    }
+
     private void inferOnService(
         Model model,
         InferenceAction.Request request,
@@ -133,6 +165,15 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
         }
     }
 
+    private ActionListener<List<ChunkedInferenceServiceResults>> createChunkedListener(
+        ActionListener<InferenceAction.Response> listener,
+        TaskType taskType
+    ) {
+        return listener.delegateFailureAndWrap((l, chunkedResults) -> {
+            l.onResponse(new InferenceAction.Response(new BatchedChunkedInferenceServiceResults(taskType, chunkedResults)));
+        });
+    }
+
     private ActionListener<InferenceServiceResults> createListener(
         InferenceAction.Request request,
         ActionListener<InferenceAction.Response> listener
