Skip to content

Commit 80eb75a

Browse files
committed
Merge branch 'main' into skip-ds-reference-check
2 parents cc2e04c + f3447d3 commit 80eb75a

File tree

74 files changed

+305
-335
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+305
-335
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@
3737
* a single large input that has been chunked may spread over
3838
* multiple batches.
3939
*
40-
* The final aspect it to gather the responses from the batch
40+
* The final aspect is to gather the responses from the batch
4141
* processing and map the results back to the original element
4242
* in the input list.
4343
*/
4444
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
4545

4646
// Visible for testing
47-
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
47+
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
4848
public String chunkText() {
49-
return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
49+
return input.substring(chunk.start(), chunk.end());
5050
}
5151
}
5252

@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
6060

6161
private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
6262

63-
// The maximum number of chunks that is stored for any input text.
63+
// The maximum number of chunks that are stored for any input text.
6464
// If the configured chunker chunks the text into more chunks, each
6565
// chunk is sent to the inference service separately, but the results
6666
// are merged so that only this maximum number of chunks is stored.
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
112112
chunkingSettings = defaultChunkingSettings;
113113
}
114114
Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
115-
List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
115+
String inputString = inputs.get(inputIndex).input();
116+
List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
116117
int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
117118
resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
118119
resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
129130
} else {
130131
resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
131132
}
132-
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
133+
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
133134
}
134135
}
135136

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,47 @@
88
package org.elasticsearch.xpack.inference.external.http.sender;
99

1010
import org.elasticsearch.core.Nullable;
11-
import org.elasticsearch.inference.ChunkInferenceInput;
12-
import org.elasticsearch.inference.ChunkingSettings;
1311
import org.elasticsearch.inference.InputType;
1412

1513
import java.util.List;
1614
import java.util.Objects;
15+
import java.util.concurrent.atomic.AtomicBoolean;
1716
import java.util.function.Supplier;
18-
import java.util.stream.Collectors;
1917

2018
public class EmbeddingsInput extends InferenceInputs {
21-
22-
public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
23-
if (inferenceInputs instanceof EmbeddingsInput == false) {
24-
throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
25-
}
26-
27-
return (EmbeddingsInput) inferenceInputs;
28-
}
29-
30-
private final Supplier<List<ChunkInferenceInput>> listSupplier;
19+
private final Supplier<List<String>> inputListSupplier;
3120
private final InputType inputType;
21+
private final AtomicBoolean supplierInvoked = new AtomicBoolean();
3222

33-
public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
34-
super(false);
35-
this.listSupplier = Objects.requireNonNull(inputSupplier);
36-
this.inputType = inputType;
23+
public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
24+
this(() -> input, inputType, false);
3725
}
3826

39-
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
40-
this(input, chunkingSettings, inputType, false);
27+
public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
28+
this(() -> input, inputType, stream);
4129
}
4230

43-
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
44-
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
31+
public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
32+
this(inputSupplier, inputType, false);
4533
}
4634

47-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
48-
this(input, inputType, false);
49-
}
50-
51-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
35+
private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
5236
super(stream);
53-
Objects.requireNonNull(input);
54-
this.listSupplier = () -> input;
37+
this.inputListSupplier = Objects.requireNonNull(inputSupplier);
5538
this.inputType = inputType;
5639
}
5740

58-
public List<ChunkInferenceInput> getInputs() {
59-
return this.listSupplier.get();
60-
}
61-
62-
public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
63-
return new EmbeddingsInput(input, null, inputType);
64-
}
65-
66-
public List<String> getStringInputs() {
67-
return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
41+
/**
42+
* Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply
43+
* returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input
44+
* Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling
45+
* this method twice in a non-production environment will cause an {@link AssertionError} to be thrown.
46+
*
47+
* @return a list of String embedding inputs
48+
*/
49+
public List<String> getInputs() {
50+
assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice";
51+
return inputListSupplier.get();
6852
}
6953

7054
public InputType getInputType() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void execute(
5252
Supplier<Boolean> hasRequestCompletedFunction,
5353
ActionListener<InferenceServiceResults> listener
5454
) {
55-
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
55+
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
5656
var truncatedInput = truncate(docsInput, maxInputTokens);
5757
var request = requestCreator.apply(truncatedInput);
5858

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ private static InferenceInputs createInput(
105105
if (validationException.validationErrors().isEmpty() == false) {
106106
throw validationException;
107107
}
108-
yield new EmbeddingsInput(input, null, inputType, stream);
108+
yield new EmbeddingsInput(input, inputType, stream);
109109
}
110110
default -> throw new ElasticsearchStatusException(
111111
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -144,7 +144,7 @@ public void chunkedInfer(
144144
}
145145

146146
// a non-null query is not supported and is dropped by all providers
147-
doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
147+
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
148148
}
149149

150150
protected abstract void doInfer(
@@ -168,7 +168,7 @@ protected abstract void doUnifiedCompletionInfer(
168168

169169
protected abstract void doChunkedInfer(
170170
Model model,
171-
EmbeddingsInput inputs,
171+
List<ChunkInferenceInput> inputs,
172172
Map<String, Object> taskSettings,
173173
InputType inputType,
174174
TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.util.LazyInitializable;
1616
import org.elasticsearch.core.Nullable;
1717
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.inference.ChunkInferenceInput;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.InferenceServiceConfiguration;
2021
import org.elasticsearch.inference.InferenceServiceExtension;
@@ -29,7 +30,6 @@
2930
import org.elasticsearch.rest.RestStatus;
3031
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
3132
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
32-
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3333
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3434
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3535
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -143,7 +143,7 @@ protected void doUnifiedCompletionInfer(
143143
@Override
144144
protected void doChunkedInfer(
145145
Model model,
146-
EmbeddingsInput inputs,
146+
List<ChunkInferenceInput> inputs,
147147
Map<String, Object> taskSettings,
148148
InputType inputType,
149149
TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ public void execute(
7171
Supplier<Boolean> hasRequestCompletedFunction,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
75-
List<String> docsInput = input.getStringInputs();
74+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
75+
76+
List<String> docsInput = input.getInputs();
7677
InputType inputType = input.getInputType();
7778

7879
AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.Nullable;
1818
import org.elasticsearch.core.Strings;
1919
import org.elasticsearch.core.TimeValue;
20+
import org.elasticsearch.inference.ChunkInferenceInput;
2021
import org.elasticsearch.inference.ChunkedInference;
2122
import org.elasticsearch.inference.ChunkingSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -336,7 +337,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V
336337
@Override
337338
protected void doChunkedInfer(
338339
Model model,
339-
EmbeddingsInput inputs,
340+
List<ChunkInferenceInput> inputs,
340341
Map<String, Object> taskSettings,
341342
InputType inputType,
342343
TimeValue timeout,
@@ -351,14 +352,14 @@ protected void doChunkedInfer(
351352
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
352353

353354
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
354-
inputs.getInputs(),
355+
inputs,
355356
EMBEDDING_MAX_BATCH_SIZE,
356357
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
357358
).batchRequestsWithListeners(listener);
358359

359360
for (var request : batchedRequests) {
360361
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
361-
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
362+
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
362363
}
363364
}
364365

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ public void execute(
7171
Supplier<Boolean> hasRequestCompletedFunction,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
75-
List<String> docsInput = input.getStringInputs();
74+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
75+
List<String> docsInput = input.getInputs();
7676
InputType inputType = input.getInputType();
7777

7878
AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ public void execute(
5656
Supplier<Boolean> hasRequestCompletedFunction,
5757
ActionListener<InferenceServiceResults> listener
5858
) {
59-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
60-
List<String> docsInput = input.getStringInputs();
59+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
60+
List<String> docsInput = input.getInputs();
6161
InputType inputType = input.getInputType();
6262

6363
var serviceSettings = embeddingsModel.getServiceSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.core.IOUtils;
1919
import org.elasticsearch.core.Nullable;
2020
import org.elasticsearch.core.TimeValue;
21+
import org.elasticsearch.inference.ChunkInferenceInput;
2122
import org.elasticsearch.inference.ChunkedInference;
2223
import org.elasticsearch.inference.ChunkingSettings;
2324
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -148,7 +149,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
148149
@Override
149150
protected void doChunkedInfer(
150151
Model model,
151-
EmbeddingsInput inputs,
152+
List<ChunkInferenceInput> inputs,
152153
Map<String, Object> taskSettings,
153154
InputType inputType,
154155
TimeValue timeout,
@@ -159,14 +160,14 @@ protected void doChunkedInfer(
159160
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
160161

161162
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
162-
inputs.getInputs(),
163+
inputs,
163164
maxBatchSize,
164165
baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
165166
).batchRequestsWithListeners(listener);
166167

167168
for (var request : batchedRequests) {
168169
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
169-
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
170+
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
170171
}
171172
} else {
172173
listener.onFailure(createInvalidModelException(model));

0 commit comments

Comments
 (0)