@@ -37,16 +37,16 @@
* a single large input that has been chunked may spread over
* multiple batches.
*
* The final aspect it to gather the responses from the batch
* The final aspect is to gather the responses from the batch
* processing and map the results back to the original element
* in the input list.
*/
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {

// Visible for testing
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
public String chunkText() {
return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
return input.substring(chunk.start(), chunk.end());
}
}

@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen

private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);

// The maximum number of chunks that is stored for any input text.
// The maximum number of chunks that are stored for any input text.
// If the configured chunker chunks the text into more chunks, each
// chunk is sent to the inference service separately, but the results
// are merged so that only this maximum number of chunks is stored.
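A hedged sketch of the merge this comment describes; the overflow handling shown here (folding every chunk past the limit into the last stored slot by extending its end offset) is an assumption for illustration, not taken from this diff:

// Hypothetical sketch: offsets for the first MAX_CHUNKS chunks are stored
// as-is; each overflow chunk only extends the end offset of the last slot,
// so the stored offsets still cover the whole input text.
public class OffsetMergeSketch {
    public static void main(String[] args) {
        int maxChunks = 3;
        int[][] chunks = { { 0, 10 }, { 10, 20 }, { 20, 30 }, { 30, 40 }, { 40, 50 } };
        int[] starts = new int[maxChunks];
        int[] ends = new int[maxChunks];
        for (int i = 0; i < chunks.length; i++) {
            int target = Math.min(i, maxChunks - 1);
            if (i < maxChunks) {
                starts[target] = chunks[i][0];
            }
            ends[target] = chunks[i][1];
        }
        // Prints: slot 2 covers [20, 50] -- three chunks merged into one slot.
        System.out.println("slot 2 covers [" + starts[2] + ", " + ends[2] + "]");
    }
}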
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
chunkingSettings = defaultChunkingSettings;
}
Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
String inputString = inputs.get(inputIndex).input();
List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
} else {
resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
}
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
}
}
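As a concrete illustration of the batching described in the class Javadoc above (hypothetical sizes and helper class, not taken from this diff): chunks from all inputs are flattened in order and split into fixed-size batches, so one large input can span several batches.

import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the input-to-batch mapping.
public class BatchingSketch {
    record Req(int inputIndex, int chunkIndex) {}

    public static void main(String[] args) {
        int batchSize = 2;
        int[] chunksPerInput = { 3, 1 }; // input 0 has 3 chunks, input 1 has 1
        List<Req> flat = new ArrayList<>();
        for (int i = 0; i < chunksPerInput.length; i++) {
            for (int c = 0; c < chunksPerInput[i]; c++) {
                flat.add(new Req(i, c));
            }
        }
        // Prints [Req[inputIndex=0, chunkIndex=0], Req[inputIndex=0, chunkIndex=1]]
        // then   [Req[inputIndex=0, chunkIndex=2], Req[inputIndex=1, chunkIndex=0]];
        // the (inputIndex, chunkIndex) pair is what routes each batch response
        // back to the original element in the input list.
        for (int from = 0; from < flat.size(); from += batchSize) {
            System.out.println(flat.subList(from, Math.min(from + batchSize, flat.size())));
        }
    }
}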

@@ -8,63 +8,40 @@
package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class EmbeddingsInput extends InferenceInputs {

public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
if (inferenceInputs instanceof EmbeddingsInput == false) {
throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
}

return (EmbeddingsInput) inferenceInputs;
}

private final Supplier<List<ChunkInferenceInput>> listSupplier;
private Supplier<List<String>> inputListSupplier;
private final InputType inputType;

public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
super(false);
this.listSupplier = Objects.requireNonNull(inputSupplier);
this.inputType = inputType;
public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
this(() -> input, inputType, false);
}

public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
this(input, chunkingSettings, inputType, false);
public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
this(() -> input, inputType, stream);
}

public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
this(inputSupplier, inputType, false);
}

public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
this(input, inputType, false);
}

public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
super(stream);
Objects.requireNonNull(input);
this.listSupplier = () -> input;
this.inputListSupplier = Objects.requireNonNull(inputSupplier);
this.inputType = inputType;
}

public List<ChunkInferenceInput> getInputs() {
return this.listSupplier.get();
}

public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
return new EmbeddingsInput(input, null, inputType);
}

public List<String> getStringInputs() {
return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
public List<String> getInputs() {
// The supplier should only be invoked once
assert inputListSupplier != null;
Contributor:

The assert is removed in production, so if this is called twice I believe it'll produce a null pointer. What if we wrapped the call in an AtomicBoolean and did a compareAndSet, and if it has already been called we throw an exception explaining why that's not allowed?

I assume we don't want to cache the result of inputListSupplier.get() in memory within this class, right? If we were OK with that, then we could do something with an AtomicReference and locking to ensure it only gets set once and is just returned after that.

This is probably overkill, but I think we only care about the situation when the class is initialized with a supplier. If the constructor were passed a List<String>, I don't think we need to guard against the supplier we create being called multiple times. If we did want to allow multiple calls in that scenario, we could use a custom internal class that wraps the List<String> and allows multiple calls. For the supplier scenario we'd have a different internal class that does the atomic boolean check and throws.
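A minimal sketch of the compareAndSet guard suggested above (hypothetical helper class, not part of this PR):

import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

// Hypothetical single-use wrapper: a second call fails with a clear message
// instead of surfacing as a NullPointerException once asserts are disabled.
class SingleUseInputSupplier {
    private final Supplier<List<String>> inputSupplier;
    private final AtomicBoolean consumed = new AtomicBoolean(false);

    SingleUseInputSupplier(Supplier<List<String>> inputSupplier) {
        this.inputSupplier = Objects.requireNonNull(inputSupplier);
    }

    List<String> get() {
        if (consumed.compareAndSet(false, true) == false) {
            throw new IllegalStateException("inputs may only be consumed once");
        }
        return inputSupplier.get();
    }
}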

List<String> strings = inputListSupplier.get();
inputListSupplier = null;
return strings;
}

public InputType getInputType() {
@@ -52,7 +52,7 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
var truncatedInput = truncate(docsInput, maxInputTokens);
var request = requestCreator.apply(truncatedInput);

@@ -105,7 +105,7 @@ private static InferenceInputs createInput(
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
yield new EmbeddingsInput(input, null, inputType, stream);
yield new EmbeddingsInput(input, inputType, stream);
}
default -> throw new ElasticsearchStatusException(
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -144,7 +144,7 @@ public void chunkedInfer(
}

// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
}

protected abstract void doInfer(
@@ -168,7 +168,7 @@ protected abstract void doUnifiedCompletionInfer(

protected abstract void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -15,6 +15,7 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
@@ -29,7 +30,6 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -143,7 +143,7 @@ protected void doUnifiedCompletionInfer(
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -71,8 +71,9 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);

List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);
@@ -17,6 +17,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -336,7 +337,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -351,14 +352,14 @@ protected void doChunkedInfer(
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
EMBEDDING_MAX_BATCH_SIZE,
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
}

@@ -71,8 +71,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);
@@ -56,8 +56,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

var serviceSettings = embeddingsModel.getServiceSettings();
@@ -18,6 +18,7 @@
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -148,7 +149,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -159,14 +160,14 @@ protected void doChunkedInfer(
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
maxBatchSize,
baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
} else {
listener.onFailure(createInvalidModelException(model));
@@ -16,6 +16,7 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
@@ -28,7 +29,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -232,7 +232,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -50,8 +50,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
@@ -17,6 +17,7 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -135,7 +136,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
@Override
protected void doChunkedInfer(
Model model,
EmbeddingsInput inputs,
List<ChunkInferenceInput> inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@@ -145,14 +146,14 @@ protected void doChunkedInfer(
var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
inputs,
EMBEDDING_MAX_BATCH_SIZE,
baseAzureAiStudioModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
}
} else {
listener.onFailure(createInvalidModelException(model));
@@ -63,8 +63,8 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
List<String> docsInput = input.getStringInputs();
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
List<String> docsInput = input.getInputs();
InputType inputType = input.getInputType();

var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());