@@ -37,16 +37,16 @@
* a single large input that has been chunked may spread over
* multiple batches.
*
* The final aspect it to gather the responses from the batch
* The final aspect is to gather the responses from the batch
* processing and map the results back to the original element
* in the input list.
*/
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {

// Visible for testing
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
public String chunkText() {
return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
return input.substring(chunk.start(), chunk.end());
}
}

@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen

private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);

// The maximum number of chunks that is stored for any input text.
// The maximum number of chunks that are stored for any input text.
// If the configured chunker chunks the text into more chunks, each
// chunk is sent to the inference service separately, but the results
// are merged so that only this maximum number of chunks is stored.
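To read this comment with concrete numbers: if MAX_CHUNKS is 3 and the configured chunker splits a text into five chunks, only resultCount = min(5, 3) = 3 result slots are allocated while all five chunks are still embedded; the overflow chunks are merged into existing slots, which is why the constructor below widens a slot's stored end offset (resultOffsetEnds.getLast().set(...)) rather than adding a new entry. The slot-assignment arithmetic itself lives in code elided from this diff.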
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
chunkingSettings = defaultChunkingSettings;
}
Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
String inputString = inputs.get(inputIndex).input();
List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
} else {
resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
}
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
}
}

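Pulling the changes in this file together: Request now captures each input's String once instead of carrying the whole ChunkInferenceInput list. A minimal sketch of the fan-out/gather flow this class provides, mirroring the service code later in this PR (finalListener, chunkingSettings, maxBatchSize, action, inputType, and timeout are hypothetical stand-ins, not part of the change):

    List<ChunkInferenceInput> inputs = List.of(
        new ChunkInferenceInput("a short document", chunkingSettings),
        new ChunkInferenceInput("a long document that the chunker will split ...", chunkingSettings)
    );

    // Chunk every input and group the chunks into batches; each batch carries a
    // listener that maps the returned embeddings back to the originating input.
    List<EmbeddingRequestChunker.BatchRequestAndListener> batches =
        new EmbeddingRequestChunker<>(inputs, maxBatchSize, chunkingSettings)
            .batchRequestsWithListeners(finalListener);

    for (var request : batches) {
        // batch().inputs() now supplies plain Strings, so EmbeddingsInput no
        // longer needs chunking settings of its own.
        action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
    }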
@@ -8,63 +8,47 @@
package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class EmbeddingsInput extends InferenceInputs {

public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
if (inferenceInputs instanceof EmbeddingsInput == false) {
throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
}

return (EmbeddingsInput) inferenceInputs;
}

private final Supplier<List<ChunkInferenceInput>> listSupplier;
private final Supplier<List<String>> inputListSupplier;
private final InputType inputType;
private final AtomicBoolean supplierInvoked = new AtomicBoolean();

public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
super(false);
this.listSupplier = Objects.requireNonNull(inputSupplier);
this.inputType = inputType;
public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
this(() -> input, inputType, false);
}

public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
this(input, chunkingSettings, inputType, false);
public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
this(() -> input, inputType, stream);
}

public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
this(inputSupplier, inputType, false);
}

public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
this(input, inputType, false);
}

public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
super(stream);
Objects.requireNonNull(input);
this.listSupplier = () -> input;
this.inputListSupplier = Objects.requireNonNull(inputSupplier);
this.inputType = inputType;
}

public List<ChunkInferenceInput> getInputs() {
return this.listSupplier.get();
}

public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
return new EmbeddingsInput(input, null, inputType);
}

public List<String> getStringInputs() {
return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
/**
* Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply
* returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input
* Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling
* this method twice in a non-production environment will cause an {@link AssertionError} to be thrown.
*
* @return a list of String embedding inputs
*/
public List<String> getInputs() {
assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice";
return inputListSupplier.get();
}

public InputType getInputType() {
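The single-use contract documented on getInputs() above can be exercised directly; a test-style sketch (the values and the InputType constant are illustrative, not from this PR):

    EmbeddingsInput input = new EmbeddingsInput(
        () -> List.of("first chunk", "second chunk"), // the supplier may be expensive, e.g. chunking on demand
        InputType.SEARCH
    );
    List<String> docs = input.getInputs(); // first call: flips supplierInvoked and runs the supplier
    // input.getInputs();                  // a second call would re-run the supplier; with assertions
    //                                     // enabled (java -ea) the guard throws an AssertionError instead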
@@ -52,7 +52,7 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
var truncatedInput = truncate(docsInput, maxInputTokens);
var request = requestCreator.apply(truncatedInput);

@@ -101,7 +101,7 @@ private static InferenceInputs createInput(
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
yield new EmbeddingsInput(input, null, inputType, stream);
yield new EmbeddingsInput(input, inputType, stream);
}
default -> throw new ElasticsearchStatusException(
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -140,7 +140,7 @@ public void chunkedInfer(
}

// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
}

protected abstract void doInfer(
@@ -164,7 +164,7 @@ protected abstract void doUnifiedCompletionInfer(

protected abstract void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -71,8 +71,9 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);

List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);
@@ -16,6 +16,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -321,7 +322,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -336,14 +337,14 @@ protected void doChunkedInfer(
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
EMBEDDING_MAX_BATCH_SIZE,
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
}

@@ -71,8 +71,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);
@@ -56,8 +56,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

var serviceSettings = embeddingsModel.getServiceSettings();
@@ -17,6 +17,7 @@
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -136,7 +137,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -147,14 +148,14 @@ protected void doChunkedInfer(
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
maxBatchSize,
baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
} else {
listener.onFailure(createInvalidModelException(model));
@@ -15,6 +15,7 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,7 +27,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -222,7 +222,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -50,8 +50,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
@@ -16,6 +16,7 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -123,7 +124,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -133,14 +134,14 @@
var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
EMBEDDING_MAX_BATCH_SIZE,
baseAzureAiStudioModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
} else {
listener.onFailure(createInvalidModelException(model));
@@ -63,8 +63,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
@@ -15,6 +15,7 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -272,7 +273,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -286,14 +287,14 @@
var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
EMBEDDING_MAX_BATCH_SIZE,
azureOpenAiModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = azureOpenAiModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
}
