Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/115784.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 115784
summary: Add chunking to perform inference API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ = def(8_776_00_0);
public static final TransportVersion SIMULATE_MAPPING_ADDITION = def(8_777_00_0);
public static final TransportVersion INTRODUCE_ALL_APPLICABLE_SELECTOR = def(8_778_00_0);
public static final TransportVersion CHUNKING_ENABLED_PERFORM_INFERENCE = def(8_779_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ public static class Request extends ActionRequest {
// Field names accepted in the request body of the inference API.
public static final ParseField INPUT = new ParseField("input");
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
public static final ParseField QUERY = new ParseField("query");
public static final ParseField CHUNKING_ENABLED = new ParseField("chunking_enabled");
public static final ParseField TIMEOUT = new ParseField("timeout");

// Parses a request body into a Request.Builder; unknown fields are rejected by ObjectParser.
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
PARSER.declareStringArray(Request.Builder::setInput, INPUT);
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
PARSER.declareString(Request.Builder::setQuery, QUERY);
PARSER.declareBoolean(Request.Builder::setChunkingEnabled, CHUNKING_ENABLED);
// Qualified as Request.Builder for consistency with the sibling declarations above.
PARSER.declareString(Request.Builder::setInferenceTimeout, TIMEOUT);
}

Expand All @@ -93,6 +95,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
private final InputType inputType;
private final TimeValue inferenceTimeout;
private final boolean stream;
private final boolean chunkingEnabled;

public Request(
TaskType taskType,
Expand All @@ -112,6 +115,29 @@ public Request(
this.inputType = inputType;
this.inferenceTimeout = inferenceTimeout;
this.stream = stream;
this.chunkingEnabled = false;
}

/**
 * Creates an inference request with an explicit chunking flag.
 *
 * @param taskType          the inference task type (e.g. text or sparse embedding)
 * @param inferenceEntityId id of the inference endpoint to run against
 * @param query             optional query string (nullable; used by rerank-style tasks)
 * @param input             the texts to run inference on
 * @param taskSettings      per-request task settings overriding the endpoint defaults
 * @param inputType         how the input is intended to be used
 * @param inferenceTimeout  how long to wait for the inference call to complete
 * @param stream            whether the response should be streamed
 * @param chunkingEnabled   whether the input should be chunked before inference;
 *                          validate() restricts this to embedding task types
 */
public Request(
TaskType taskType,
String inferenceEntityId,
String query,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue inferenceTimeout,
boolean stream,
boolean chunkingEnabled
) {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
this.input = input;
this.taskSettings = taskSettings;
this.inputType = inputType;
this.inferenceTimeout = inferenceTimeout;
this.stream = stream;
this.chunkingEnabled = chunkingEnabled;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -138,6 +164,12 @@ public Request(StreamInput in) throws IOException {
this.inferenceTimeout = DEFAULT_TIMEOUT;
}

if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNKING_ENABLED_PERFORM_INFERENCE)) {
this.chunkingEnabled = in.readBoolean();
} else {
this.chunkingEnabled = false;
}

// streaming is not supported yet for transport traffic
this.stream = false;
}
Expand Down Expand Up @@ -174,6 +206,10 @@ public boolean isStreaming() {
return stream;
}

/**
 * Returns whether chunking was requested for this inference call.
 * Defaults to {@code false} for the legacy constructor and for requests
 * read from nodes on a transport version before CHUNKING_ENABLED_PERFORM_INFERENCE.
 */
public boolean isChunkingEnabled() {
return chunkingEnabled;
}

@Override
public ActionRequestValidationException validate() {
if (input == null) {
Expand Down Expand Up @@ -201,6 +237,12 @@ public ActionRequestValidationException validate() {
}
}

if (chunkingEnabled && ((taskType.equals(TaskType.SPARSE_EMBEDDING) || taskType.equals(TaskType.TEXT_EMBEDDING)) == false)) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Chunking is only supported for embedding task types."));
return e;
}

return null;
}

Expand All @@ -224,6 +266,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(query);
out.writeTimeValue(inferenceTimeout);
}

if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNKING_ENABLED_PERFORM_INFERENCE)) {
out.writeBoolean(chunkingEnabled);
}
}

// default for easier testing
Expand All @@ -250,12 +296,13 @@ public boolean equals(Object o) {
&& Objects.equals(taskSettings, request.taskSettings)
&& Objects.equals(inputType, request.inputType)
&& Objects.equals(query, request.query)
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
&& Objects.equals(inferenceTimeout, request.inferenceTimeout)
&& Objects.equals(chunkingEnabled, request.chunkingEnabled);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, chunkingEnabled, inferenceTimeout);
}

public static class Builder {
Expand All @@ -266,6 +313,7 @@ public static class Builder {
private InputType inputType = InputType.UNSPECIFIED;
private Map<String, Object> taskSettings = Map.of();
private String query;
private boolean chunkingEnabled = false;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;

Expand All @@ -291,6 +339,11 @@ public Builder setQuery(String query) {
return this;
}

/**
 * Sets whether the input should be chunked before inference.
 * Bound to the {@code chunking_enabled} request field by the parser; defaults to {@code false}.
 *
 * @return this builder, for chaining
 */
public Builder setChunkingEnabled(boolean chunkingEnabled) {
this.chunkingEnabled = chunkingEnabled;
return this;
}

public Builder setInputType(InputType inputType) {
this.inputType = inputType;
return this;
Expand All @@ -316,7 +369,7 @@ public Builder setStream(boolean stream) {
}

public Request build() {
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, chunkingEnabled);
}
}

Expand All @@ -335,6 +388,8 @@ public String toString() {
+ this.getInputType()
+ ", timeout="
+ this.getInferenceTimeout()
+ ", chunking_enabled="
+ this.isChunkingEnabled()
+ ")";
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContent;

import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Wraps the per-input chunked results produced by a chunked inference call so they
 * can be returned as a single {@link InferenceServiceResults} response.
 *
 * The concrete element type of {@link #chunkedInferenceServiceResults} is not encoded
 * on the wire; it is derived from the {@link TaskType} when deserializing, so only
 * TEXT_EMBEDDING and SPARSE_EMBEDDING are supported.
 */
public class BatchedChunkedInferenceServiceResults implements InferenceServiceResults {
    public static final String NAME = "batched_chunked_inference_service_results";

    private final TaskType taskType;
    private final List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults;

    public BatchedChunkedInferenceServiceResults(TaskType taskType, List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults) {
        this.taskType = taskType;
        this.chunkedInferenceServiceResults = chunkedInferenceServiceResults;
    }

    /**
     * Reads the task type first, then uses it to pick the concrete reader for the
     * result list, since the element type is not self-describing on the wire.
     */
    public BatchedChunkedInferenceServiceResults(StreamInput in) throws IOException {
        this.taskType = in.readEnum(TaskType.class);

        switch (taskType) {
            case TEXT_EMBEDDING:
                // NOTE(review): reads byte-embedding results; confirm float embeddings
                // are not expected for TEXT_EMBEDDING here — TODO from original author.
                this.chunkedInferenceServiceResults = in.readCollectionAsList(InferenceChunkedTextEmbeddingByteResults::new);
                break;
            case SPARSE_EMBEDDING:
                this.chunkedInferenceServiceResults = in.readCollectionAsList(InferenceChunkedSparseEmbeddingResults::new);
                break;
            default:
                throw new IllegalArgumentException("Unknown task type: " + taskType);
        }
    }

    public TaskType getTaskType() {
        return taskType;
    }

    public List<ChunkedInferenceServiceResults> getChunkedInferenceServiceResults() {
        return chunkedInferenceServiceResults;
    }

    /** Coordination format is not supported for batched chunked results. */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        return List.of();
    }

    /** Legacy format is not supported for batched chunked results. */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        return List.of();
    }

    @Override
    public Map<String, Object> asMap() {
        return Map.of();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Mirrors the StreamInput constructor: task type first, then the result list. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeEnum(taskType);
        out.writeCollection(chunkedInferenceServiceResults, StreamOutput::writeWriteable);
    }

    /** Concatenates the per-input chunk iterators; an empty batch renders as no chunks. */
    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        return chunkedInferenceServiceResults.stream()
            .map(result -> result.toXContentChunked(params))
            .reduce(Iterators::concat)
            .orElseGet(Collections::emptyIterator);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        BatchedChunkedInferenceServiceResults that = (BatchedChunkedInferenceServiceResults) o;
        return Objects.equals(taskType, that.taskType)
            && Objects.equals(chunkedInferenceServiceResults, that.chunkedInferenceServiceResults);
    }

    @Override
    public int hashCode() {
        return Objects.hash(taskType, chunkedInferenceServiceResults);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
Expand All @@ -24,10 +26,12 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.BatchedChunkedInferenceServiceResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -83,12 +87,40 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
unparsedModel.settings(),
unparsedModel.secrets()
);
inferOnService(model, request, service.get(), delegate);
if (request.isChunkingEnabled()) {
chunkedInferOnService(model, request, service.get(), delegate);
} else {
inferOnService(model, request, service.get(), delegate);
}
});

modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
}

/**
 * Runs chunked inference against the service for the resolved model.
 * Streaming requests are rejected unless the service can stream this task type.
 */
private void chunkedInferOnService(
    Model model,
    InferenceAction.Request request,
    InferenceService service,
    ActionListener<InferenceAction.Response> listener
) {
    // TODO: Check if this streaming guard is necessary for the chunked path.
    if (request.isStreaming() && service.canStream(request.getTaskType()) == false) {
        listener.onFailure(unsupportedStreamingTaskException(request, service));
        return;
    }
    inferenceStats.incrementRequestCount(model);
    service.chunkedInfer(
        model,
        request.getQuery(),
        request.getInput(),
        request.getTaskSettings(),
        request.getInputType(),
        new ChunkingOptions(null, null),
        request.getInferenceTimeout(),
        createChunkedListener(listener, request.getTaskType())
    );
}

private void inferOnService(
Model model,
InferenceAction.Request request,
Expand Down Expand Up @@ -133,6 +165,15 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
}
}

/**
 * Adapts the service's chunked-results callback into an action response listener,
 * batching the per-input chunked results into a single response. Failures are
 * delegated to the wrapped listener unchanged.
 */
private ActionListener<List<ChunkedInferenceServiceResults>> createChunkedListener(
    ActionListener<InferenceAction.Response> listener,
    TaskType taskType
) {
    return listener.delegateFailureAndWrap(
        (delegate, chunkedResults) -> delegate.onResponse(
            new InferenceAction.Response(new BatchedChunkedInferenceServiceResults(taskType, chunkedResults))
        )
    );
}

private ActionListener<InferenceServiceResults> createListener(
InferenceAction.Request request,
ActionListener<InferenceAction.Response> listener
Expand Down
Loading