Skip to content

Commit 835fa76

Browse files
authored
Merge branch '9.1' into mtv21bp91
2 parents bb6c44f + 1db5f0d commit 835fa76

File tree

69 files changed

+291
-297
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+291
-297
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@
3737
* a single large input that has been chunked may spread over
3838
* multiple batches.
3939
*
40-
* The final aspect it to gather the responses from the batch
40+
* The final aspect is to gather the responses from the batch
4141
* processing and map the results back to the original element
4242
* in the input list.
4343
*/
4444
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
4545

4646
// Visible for testing
47-
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
47+
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
4848
public String chunkText() {
49-
return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
49+
return input.substring(chunk.start(), chunk.end());
5050
}
5151
}
5252

@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
6060

6161
private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
6262

63-
// The maximum number of chunks that is stored for any input text.
63+
// The maximum number of chunks that are stored for any input text.
6464
// If the configured chunker chunks the text into more chunks, each
6565
// chunk is sent to the inference service separately, but the results
6666
// are merged so that only this maximum number of chunks is stored.
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
112112
chunkingSettings = defaultChunkingSettings;
113113
}
114114
Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
115-
List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
115+
String inputString = inputs.get(inputIndex).input();
116+
List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
116117
int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
117118
resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
118119
resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
129130
} else {
130131
resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
131132
}
132-
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
133+
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
133134
}
134135
}
135136

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,47 @@
88
package org.elasticsearch.xpack.inference.external.http.sender;
99

1010
import org.elasticsearch.core.Nullable;
11-
import org.elasticsearch.inference.ChunkInferenceInput;
12-
import org.elasticsearch.inference.ChunkingSettings;
1311
import org.elasticsearch.inference.InputType;
1412

1513
import java.util.List;
1614
import java.util.Objects;
15+
import java.util.concurrent.atomic.AtomicBoolean;
1716
import java.util.function.Supplier;
18-
import java.util.stream.Collectors;
1917

2018
public class EmbeddingsInput extends InferenceInputs {
21-
22-
public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
23-
if (inferenceInputs instanceof EmbeddingsInput == false) {
24-
throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
25-
}
26-
27-
return (EmbeddingsInput) inferenceInputs;
28-
}
29-
30-
private final Supplier<List<ChunkInferenceInput>> listSupplier;
19+
private final Supplier<List<String>> inputListSupplier;
3120
private final InputType inputType;
21+
private final AtomicBoolean supplierInvoked = new AtomicBoolean();
3222

33-
public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
34-
super(false);
35-
this.listSupplier = Objects.requireNonNull(inputSupplier);
36-
this.inputType = inputType;
23+
public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
24+
this(() -> input, inputType, false);
3725
}
3826

39-
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
40-
this(input, chunkingSettings, inputType, false);
27+
public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
28+
this(() -> input, inputType, stream);
4129
}
4230

43-
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
44-
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
31+
public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
32+
this(inputSupplier, inputType, false);
4533
}
4634

47-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
48-
this(input, inputType, false);
49-
}
50-
51-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
35+
private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
5236
super(stream);
53-
Objects.requireNonNull(input);
54-
this.listSupplier = () -> input;
37+
this.inputListSupplier = Objects.requireNonNull(inputSupplier);
5538
this.inputType = inputType;
5639
}
5740

58-
public List<ChunkInferenceInput> getInputs() {
59-
return this.listSupplier.get();
60-
}
61-
62-
public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
63-
return new EmbeddingsInput(input, null, inputType);
64-
}
65-
66-
public List<String> getStringInputs() {
67-
return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
41+
/**
42+
* Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply
43+
* returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input
44+
* Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling
45+
* this method twice in a non-production environment will cause an {@link AssertionError} to be thrown.
46+
*
47+
* @return a list of String embedding inputs
48+
*/
49+
public List<String> getInputs() {
50+
assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice";
51+
return inputListSupplier.get();
6852
}
6953

7054
public InputType getInputType() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void execute(
5252
Supplier<Boolean> hasRequestCompletedFunction,
5353
ActionListener<InferenceServiceResults> listener
5454
) {
55-
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
55+
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
5656
var truncatedInput = truncate(docsInput, maxInputTokens);
5757
var request = requestCreator.apply(truncatedInput);
5858

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private static InferenceInputs createInput(
101101
if (validationException.validationErrors().isEmpty() == false) {
102102
throw validationException;
103103
}
104-
yield new EmbeddingsInput(input, null, inputType, stream);
104+
yield new EmbeddingsInput(input, inputType, stream);
105105
}
106106
default -> throw new ElasticsearchStatusException(
107107
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -140,7 +140,7 @@ public void chunkedInfer(
140140
}
141141

142142
// a non-null query is not supported and is dropped by all providers
143-
doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
143+
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
144144
}
145145

146146
protected abstract void doInfer(
@@ -164,7 +164,7 @@ protected abstract void doUnifiedCompletionInfer(
164164

165165
protected abstract void doChunkedInfer(
166166
Model model,
167-
EmbeddingsInput inputs,
167+
List<ChunkInferenceInput> inputs,
168168
Map<String, Object> taskSettings,
169169
InputType inputType,
170170
TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ public void execute(
7171
Supplier<Boolean> hasRequestCompletedFunction,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
75-
List<String> docsInput = input.getStringInputs();
74+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
75+
76+
List<String> docsInput = input.getInputs();
7677
InputType inputType = input.getInputType();
7778

7879
AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.core.Nullable;
1717
import org.elasticsearch.core.Strings;
1818
import org.elasticsearch.core.TimeValue;
19+
import org.elasticsearch.inference.ChunkInferenceInput;
1920
import org.elasticsearch.inference.ChunkedInference;
2021
import org.elasticsearch.inference.ChunkingSettings;
2122
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -321,7 +322,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V
321322
@Override
322323
protected void doChunkedInfer(
323324
Model model,
324-
EmbeddingsInput inputs,
325+
List<ChunkInferenceInput> inputs,
325326
Map<String, Object> taskSettings,
326327
InputType inputType,
327328
TimeValue timeout,
@@ -336,14 +337,14 @@ protected void doChunkedInfer(
336337
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
337338

338339
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
339-
inputs.getInputs(),
340+
inputs,
340341
EMBEDDING_MAX_BATCH_SIZE,
341342
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
342343
).batchRequestsWithListeners(listener);
343344

344345
for (var request : batchedRequests) {
345346
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
346-
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
347+
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
347348
}
348349
}
349350

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ public void execute(
7171
Supplier<Boolean> hasRequestCompletedFunction,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
75-
List<String> docsInput = input.getStringInputs();
74+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
75+
List<String> docsInput = input.getInputs();
7676
InputType inputType = input.getInputType();
7777

7878
AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ public void execute(
5656
Supplier<Boolean> hasRequestCompletedFunction,
5757
ActionListener<InferenceServiceResults> listener
5858
) {
59-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
60-
List<String> docsInput = input.getStringInputs();
59+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
60+
List<String> docsInput = input.getInputs();
6161
InputType inputType = input.getInputType();
6262

6363
var serviceSettings = embeddingsModel.getServiceSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.IOUtils;
1818
import org.elasticsearch.core.Nullable;
1919
import org.elasticsearch.core.TimeValue;
20+
import org.elasticsearch.inference.ChunkInferenceInput;
2021
import org.elasticsearch.inference.ChunkedInference;
2122
import org.elasticsearch.inference.ChunkingSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -136,7 +137,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
136137
@Override
137138
protected void doChunkedInfer(
138139
Model model,
139-
EmbeddingsInput inputs,
140+
List<ChunkInferenceInput> inputs,
140141
Map<String, Object> taskSettings,
141142
InputType inputType,
142143
TimeValue timeout,
@@ -147,14 +148,14 @@ protected void doChunkedInfer(
147148
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
148149

149150
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
150-
inputs.getInputs(),
151+
inputs,
151152
maxBatchSize,
152153
baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
153154
).batchRequestsWithListeners(listener);
154155

155156
for (var request : batchedRequests) {
156157
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
157-
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
158+
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
158159
}
159160
} else {
160161
listener.onFailure(createInvalidModelException(model));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.util.LazyInitializable;
1616
import org.elasticsearch.core.Nullable;
1717
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.inference.ChunkInferenceInput;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.InferenceServiceConfiguration;
2021
import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,7 +27,6 @@
2627
import org.elasticsearch.inference.TaskType;
2728
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
2829
import org.elasticsearch.rest.RestStatus;
29-
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3030
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3131
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3232
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -222,7 +222,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
222222
@Override
223223
protected void doChunkedInfer(
224224
Model model,
225-
EmbeddingsInput inputs,
225+
List<ChunkInferenceInput> inputs,
226226
Map<String, Object> taskSettings,
227227
InputType inputType,
228228
TimeValue timeout,

0 commit comments

Comments
 (0)