diff --git a/docs/changelog/115784.yaml b/docs/changelog/115784.yaml new file mode 100644 index 0000000000000..f2431a74c312d --- /dev/null +++ b/docs/changelog/115784.yaml @@ -0,0 +1,5 @@ +pr: 115784 +summary: Add chunking to perform inference API +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 777ff083f33f8..01ab4903146e1 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -181,6 +181,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ = def(8_776_00_0); public static final TransportVersion SIMULATE_MAPPING_ADDITION = def(8_777_00_0); public static final TransportVersion INTRODUCE_ALL_APPLICABLE_SELECTOR = def(8_778_00_0); + public static final TransportVersion CHUNKING_ENABLED_PERFORM_INFERENCE = def(8_779_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index a19edd5a08162..4bd30986d836d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -60,6 +60,7 @@ public static class Request extends ActionRequest { public static final ParseField INPUT = new ParseField("input"); public static final ParseField TASK_SETTINGS = new ParseField("task_settings"); public static final ParseField QUERY = new ParseField("query"); + public static final ParseField CHUNKING_ENABLED = new ParseField("chunking_enabled"); public static final ParseField TIMEOUT = new ParseField("timeout"); static final ObjectParser PARSER = new ObjectParser<>(NAME, Request.Builder::new); @@ -67,6 +68,7 @@ public static class Request extends ActionRequest { PARSER.declareStringArray(Request.Builder::setInput, INPUT); PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS); PARSER.declareString(Request.Builder::setQuery, QUERY); + PARSER.declareBoolean(Request.Builder::setChunkingEnabled, CHUNKING_ENABLED); PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT); } @@ -93,6 +95,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final InputType inputType; private final TimeValue inferenceTimeout; private final boolean stream; + private final boolean chunkingEnabled; public Request( TaskType taskType, @@ -112,6 +115,29 @@ public Request( this.inputType = inputType; this.inferenceTimeout = inferenceTimeout; this.stream = stream; + this.chunkingEnabled = false; + } + + public Request( + TaskType taskType, + String inferenceEntityId, + String query, + List input, + Map taskSettings, + 
InputType inputType, + TimeValue inferenceTimeout, + boolean stream, + boolean chunkingEnabled + ) { + this.taskType = taskType; + this.inferenceEntityId = inferenceEntityId; + this.query = query; + this.input = input; + this.taskSettings = taskSettings; + this.inputType = inputType; + this.inferenceTimeout = inferenceTimeout; + this.stream = stream; + this.chunkingEnabled = chunkingEnabled; } public Request(StreamInput in) throws IOException { @@ -138,6 +164,12 @@ public Request(StreamInput in) throws IOException { this.inferenceTimeout = DEFAULT_TIMEOUT; } + if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNKING_ENABLED_PERFORM_INFERENCE)) { + this.chunkingEnabled = in.readBoolean(); + } else { + this.chunkingEnabled = false; + } + // streaming is not supported yet for transport traffic this.stream = false; } @@ -174,6 +206,10 @@ public boolean isStreaming() { return stream; } + public boolean isChunkingEnabled() { + return chunkingEnabled; + } + @Override public ActionRequestValidationException validate() { if (input == null) { @@ -201,6 +237,12 @@ public ActionRequestValidationException validate() { } } + if (chunkingEnabled && ((taskType.equals(TaskType.SPARSE_EMBEDDING) || taskType.equals(TaskType.TEXT_EMBEDDING)) == false)) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Chunking is only supported for embedding task types.")); + return e; + } + return null; } @@ -224,6 +266,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(query); out.writeTimeValue(inferenceTimeout); } + + if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNKING_ENABLED_PERFORM_INFERENCE)) { + out.writeBoolean(chunkingEnabled); + } } // default for easier testing @@ -250,12 +296,13 @@ public boolean equals(Object o) { && Objects.equals(taskSettings, request.taskSettings) && Objects.equals(inputType, request.inputType) && Objects.equals(query, request.query) - && Objects.equals(inferenceTimeout, 
request.inferenceTimeout); + && Objects.equals(inferenceTimeout, request.inferenceTimeout) + && Objects.equals(chunkingEnabled, request.chunkingEnabled); } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout); + return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, chunkingEnabled, inferenceTimeout); } public static class Builder { @@ -266,6 +313,7 @@ public static class Builder { private InputType inputType = InputType.UNSPECIFIED; private Map taskSettings = Map.of(); private String query; + private boolean chunkingEnabled = false; private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; @@ -291,6 +339,11 @@ public Builder setQuery(String query) { return this; } + public Builder setChunkingEnabled(boolean chunkingEnabled) { + this.chunkingEnabled = chunkingEnabled; + return this; + } + public Builder setInputType(InputType inputType) { this.inputType = inputType; return this; @@ -316,7 +369,7 @@ public Builder setStream(boolean stream) { } public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream); + return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, chunkingEnabled); } } @@ -335,6 +388,8 @@ public String toString() { + this.getInputType() + ", timeout=" + this.getInferenceTimeout() + + ", chunking_enabled=" + + this.isChunkingEnabled() + ")"; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/BatchedChunkedInferenceServiceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/BatchedChunkedInferenceServiceResults.java new file mode 100644 index 0000000000000..7cf0287b691e6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/BatchedChunkedInferenceServiceResults.java @@ 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContent;

import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Holds one {@link ChunkedInferenceServiceResults} per input of a chunked inference request.
 *
 * The concrete chunked-result class differs per task type, so the {@link TaskType} is written
 * to the wire first and used on the read side to pick the correct deserializer.
 * Only embedding task types are supported (enforced in the stream constructor and by
 * request validation upstream).
 */
public class BatchedChunkedInferenceServiceResults implements InferenceServiceResults {
    public static final String NAME = "batched_chunked_inference_service_results";

    private final TaskType taskType;
    private final List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults;

    public BatchedChunkedInferenceServiceResults(TaskType taskType, List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults) {
        this.taskType = taskType;
        this.chunkedInferenceServiceResults = chunkedInferenceServiceResults;
    }

    public BatchedChunkedInferenceServiceResults(StreamInput in) throws IOException {
        // The task type decides which concrete result class was serialized.
        this.taskType = in.readEnum(TaskType.class);

        switch (taskType) {
            case TEXT_EMBEDDING:
                // NOTE(review): deserializes the byte-embedding variant; confirm the float
                // variant is never sent on this path.
                this.chunkedInferenceServiceResults = in.readCollectionAsList(InferenceChunkedTextEmbeddingByteResults::new);
                break;
            case SPARSE_EMBEDDING:
                this.chunkedInferenceServiceResults = in.readCollectionAsList(InferenceChunkedSparseEmbeddingResults::new);
                break;
            default:
                throw new IllegalArgumentException("Unknown task type: " + taskType);
        }
    }

    public TaskType getTaskType() {
        return taskType;
    }

    public List<ChunkedInferenceServiceResults> getChunkedInferenceServiceResults() {
        return chunkedInferenceServiceResults;
    }

    /** Chunked results are returned to the caller directly; no coordination format exists. */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        return List.of();
    }

    /** Chunked results are returned to the caller directly; no legacy format exists. */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        return List.of();
    }

    /** Rendering is handled by {@link #toXContentChunked}; the map form is unused. */
    @Override
    public Map<String, Object> asMap() {
        return Map.of();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Write the task type first so the stream constructor can pick the right reader,
        // then each element via its own Writeable implementation.
        out.writeEnum(taskType);
        out.writeCollection(chunkedInferenceServiceResults, StreamOutput::writeWriteable);
    }

    @Override
    public Iterator<ToXContent> toXContentChunked(ToXContent.Params params) {
        // Concatenate each per-input result's chunked rendering; an empty batch renders
        // as nothing rather than failing.
        return chunkedInferenceServiceResults.stream()
            .map(result -> result.toXContentChunked(params))
            .reduce(Iterators::concat)
            .orElseGet(Collections::emptyIterator);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        BatchedChunkedInferenceServiceResults that = (BatchedChunkedInferenceServiceResults) o;
        return taskType == that.taskType
            && Objects.equals(chunkedInferenceServiceResults, that.chunkedInferenceServiceResults);
    }

    @Override
    public int hashCode() {
        return Objects.hash(taskType, chunkedInferenceServiceResults);
    }
}
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -13,6 +13,8 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; @@ -24,10 +26,12 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.BatchedChunkedInferenceServiceResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -83,12 +87,40 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe unparsedModel.settings(), unparsedModel.secrets() ); - inferOnService(model, request, service.get(), delegate); + if (request.isChunkingEnabled()) { + chunkedInferOnService(model, request, service.get(), delegate); + } else { + inferOnService(model, request, service.get(), delegate); + } }); modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); } + private void chunkedInferOnService( + Model model, + InferenceAction.Request request, + InferenceService service, + ActionListener listener + ) { + // TODO: Check if the if statement is necessary + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + inferenceStats.incrementRequestCount(model); + 
service.chunkedInfer( + model, + request.getQuery(), + request.getInput(), + request.getTaskSettings(), + request.getInputType(), + new ChunkingOptions(null, null), + request.getInferenceTimeout(), + createChunkedListener(listener, request.getTaskType()) + ); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + private void inferOnService( Model model, InferenceAction.Request request, @@ -133,6 +165,15 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference } } + private ActionListener> createChunkedListener( + ActionListener listener, + TaskType taskType + ) { + return listener.delegateFailureAndWrap((l, chunkedResults) -> { + l.onResponse(new InferenceAction.Response(new BatchedChunkedInferenceServiceResults(taskType, chunkedResults))); + }); + } + private ActionListener createListener( InferenceAction.Request request, ActionListener listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/BatchedChunkedInferenceServiceResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/BatchedChunkedInferenceServiceResultsTests.java new file mode 100644 index 0000000000000..6fbac2614bfa6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/BatchedChunkedInferenceServiceResultsTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.BatchedChunkedInferenceServiceResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class BatchedChunkedInferenceServiceResultsTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return BatchedChunkedInferenceServiceResults::new; + } + + @Override + protected BatchedChunkedInferenceServiceResults createTestInstance() { + var taskType = randomFrom(List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)); + return new BatchedChunkedInferenceServiceResults(taskType, createRandomChunkedInferenceServiceResults(taskType)); + } + + private List createRandomChunkedInferenceServiceResults(TaskType taskType) { + var results = new ArrayList(); + var resultsCount = randomIntBetween(1, 5); + + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + for (int i = 0; i < resultsCount; i++) { + results.add(InferenceChunkedTextEmbeddingByteResultsTests.createRandomResults()); + } + } else { + for (int i = 0; i < resultsCount; i++) { + results.add(InferenceChunkedSparseEmbeddingResultsTests.createRandomResults()); + } + } + + return results; + } + + @Override + protected BatchedChunkedInferenceServiceResults mutateInstance(BatchedChunkedInferenceServiceResults instance) throws IOException { + var taskType = randomFrom(List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)); + var results = randomValueOtherThan( + instance.getChunkedInferenceServiceResults(), + () -> createRandomChunkedInferenceServiceResults(taskType) + ); + + return new BatchedChunkedInferenceServiceResults(taskType, results); + } +}