Skip to content

Commit 74ccbb6

Browse files
committed
Remove specialized chunks
1 parent 5c209d9 commit 74ccbb6

File tree

25 files changed

+216 additions, -225 deletions

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
1919

20-
public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
20+
public record ChunkedInferenceEmbedding(List<EmbeddingResults.Chunk> chunks) implements ChunkedInference {
2121

2222
public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
2323
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
@@ -27,10 +27,7 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
2727
results.add(
2828
new ChunkedInferenceEmbedding(
2929
List.of(
30-
new SparseEmbeddingResults.Chunk(
31-
sparseEmbeddingResults.embeddings().get(i).tokens(),
32-
new TextOffset(0, inputs.get(i).length())
33-
)
30+
new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length()))
3431
)
3532
)
3633
);
@@ -41,10 +38,10 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
4138

4239
@Override
4340
public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
44-
var asChunk = new ArrayList<Chunk>();
45-
for (var chunk : chunks()) {
46-
asChunk.add(chunk.toChunk(xcontent));
41+
List<Chunk> chunkedInferenceChunks = new ArrayList<>();
42+
for (EmbeddingResults.Chunk embeddingResultsChunk : chunks()) {
43+
chunkedInferenceChunks.add(new Chunk(embeddingResultsChunk.offset(), embeddingResultsChunk.embedding().toBytesRef(xcontent)));
4744
}
48-
return asChunk.iterator();
45+
return chunkedInferenceChunks.iterator();
4946
}
5047
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java

Lines changed: 10 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.results;
99

10+
import org.elasticsearch.common.bytes.BytesReference;
1011
import org.elasticsearch.inference.ChunkedInference;
1112
import org.elasticsearch.inference.InferenceServiceResults;
1213
import org.elasticsearch.xcontent.XContent;
@@ -21,32 +22,28 @@
2122
*/
2223
public interface EmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends InferenceServiceResults {
2324

24-
/**
25-
* A resulting embedding together with the offset into the input text.
26-
*/
27-
interface Chunk {
28-
ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
29-
30-
ChunkedInference.TextOffset offset();
31-
}
32-
3325
/**
3426
* A resulting embedding for one of the input texts to the inference service.
3527
*/
3628
interface Embedding<E extends Embedding<E>> {
3729
/**
38-
* Combines the resulting embedding with the offset into the input text into a chunk.
30+
* Merges the existing embedding and provided embedding into a new embedding.
3931
*/
40-
Chunk toChunk(ChunkedInference.TextOffset offset);
32+
E merge(E embedding);
4133

4234
/**
43-
* Merges the existing embedding and provided embedding into a new embedding.
35+
* Serializes the embedding to bytes.
4436
*/
45-
E merge(E embedding);
37+
BytesReference toBytesRef(XContent xContent) throws IOException;
4638
}
4739

4840
/**
4941
* The resulting list of embeddings for the input texts to the inference service.
5042
*/
5143
List<E> embeddings();
44+
45+
/**
46+
* A resulting embedding together with the offset into the input text.
47+
*/
48+
record Chunk(Embedding<?> embedding, ChunkedInference.TextOffset offset) {}
5249
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java

Lines changed: 2 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.common.io.stream.Writeable;
1616
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
17-
import org.elasticsearch.inference.ChunkedInference;
1817
import org.elasticsearch.inference.InferenceResults;
1918
import org.elasticsearch.inference.TaskType;
2019
import org.elasticsearch.rest.RestStatus;
@@ -174,11 +173,6 @@ public String toString() {
174173
return Strings.toString(this);
175174
}
176175

177-
@Override
178-
public Chunk toChunk(ChunkedInference.TextOffset offset) {
179-
return new Chunk(tokens, offset);
180-
}
181-
182176
@Override
183177
public Embedding merge(Embedding embedding) {
184178
List<WeightedToken> mergedTokens = new ArrayList<>();
@@ -204,15 +198,9 @@ public Embedding merge(Embedding embedding) {
204198
boolean mergedIsTruncated = isTruncated || embedding.isTruncated();
205199
return new Embedding(mergedTokens, mergedIsTruncated);
206200
}
207-
}
208-
209-
public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
210-
211-
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
212-
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
213-
}
214201

215-
private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
202+
@Override
203+
public BytesReference toBytesRef(XContent xContent) throws IOException {
216204
XContentBuilder b = XContentBuilder.builder(xContent);
217205
b.startObject();
218206
for (var weightedToken : tokens) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

Lines changed: 4 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.io.stream.Writeable;
1717
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
18-
import org.elasticsearch.inference.ChunkedInference;
1918
import org.elasticsearch.inference.InferenceResults;
2019
import org.elasticsearch.xcontent.ToXContent;
2120
import org.elasticsearch.xcontent.ToXContentObject;
@@ -195,11 +194,6 @@ public int hashCode() {
195194
return Arrays.hashCode(values);
196195
}
197196

198-
@Override
199-
public Chunk toChunk(ChunkedInference.TextOffset offset) {
200-
return new Chunk(values, offset);
201-
}
202-
203197
@Override
204198
public Embedding merge(Embedding embedding) {
205199
byte[] newValues = new byte[values.length];
@@ -214,22 +208,13 @@ public Embedding merge(Embedding embedding) {
214208
}
215209
return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
216210
}
217-
}
218-
219-
/**
220-
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
221-
*/
222-
public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
223-
224-
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
225-
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
226-
}
227211

228-
private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
212+
@Override
213+
public BytesReference toBytesRef(XContent xContent) throws IOException {
229214
XContentBuilder builder = XContentBuilder.builder(xContent);
230215
builder.startArray();
231-
for (byte v : value) {
232-
builder.value(v);
216+
for (byte value : values) {
217+
builder.value(value);
233218
}
234219
builder.endArray();
235220
return BytesReference.bytes(builder);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java

Lines changed: 4 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
import org.elasticsearch.common.io.stream.StreamOutput;
1717
import org.elasticsearch.common.io.stream.Writeable;
1818
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
19-
import org.elasticsearch.inference.ChunkedInference;
2019
import org.elasticsearch.inference.InferenceResults;
2120
import org.elasticsearch.inference.TaskType;
2221
import org.elasticsearch.rest.RestStatus;
@@ -228,11 +227,6 @@ public int hashCode() {
228227
return Arrays.hashCode(values);
229228
}
230229

231-
@Override
232-
public Chunk toChunk(ChunkedInference.TextOffset offset) {
233-
return new Chunk(values, offset);
234-
}
235-
236230
@Override
237231
public Embedding merge(Embedding embedding) {
238232
float[] mergedValues = new float[values.length];
@@ -242,22 +236,13 @@ public Embedding merge(Embedding embedding) {
242236
}
243237
return new Embedding(mergedValues, numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
244238
}
245-
}
246-
247-
public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
248239

249-
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
250-
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
251-
}
252-
253-
/**
254-
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
255-
*/
256-
private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
240+
@Override
241+
public BytesReference toBytesRef(XContent xContent) throws IOException {
257242
XContentBuilder b = XContentBuilder.builder(xContent);
258243
b.startArray();
259-
for (float v : value) {
260-
b.value(v);
244+
for (float value : values) {
245+
b.value(value);
261246
}
262247
b.endArray();
263248
return BytesReference.bytes(b);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
3737
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
38+
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
3839
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
3940

4041
import java.io.IOException;
@@ -181,8 +182,8 @@ private List<ChunkedInference> makeChunkedResults(List<String> input, int dimens
181182
results.add(
182183
new ChunkedInferenceEmbedding(
183184
List.of(
184-
new TextEmbeddingFloatResults.Chunk(
185-
nonChunkedResults.embeddings().get(i).values(),
185+
new EmbeddingResults.Chunk(
186+
nonChunkedResults.embeddings().get(i),
186187
new ChunkedInference.TextOffset(0, input.get(i).length())
187188
)
188189
)

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.xcontent.ToXContentObject;
3434
import org.elasticsearch.xcontent.XContentBuilder;
3535
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
36+
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
3637
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3738
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
3839

@@ -172,7 +173,12 @@ private List<ChunkedInference> makeChunkedResults(List<String> input) {
172173
}
173174
results.add(
174175
new ChunkedInferenceEmbedding(
175-
List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
176+
List.of(
177+
new EmbeddingResults.Chunk(
178+
new SparseEmbeddingResults.Embedding(tokens, false),
179+
new ChunkedInference.TextOffset(0, input.get(i).length())
180+
)
181+
)
176182
)
177183
);
178184
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -219,10 +219,9 @@ private ChunkedInference mergeResultsWithInputs(int inputIndex) {
219219
AtomicReferenceArray<E> embeddings = resultEmbeddings.get(inputIndex);
220220

221221
List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
222-
223222
for (int i = 0; i < embeddings.length(); i++) {
224223
ChunkedInference.TextOffset offset = new ChunkedInference.TextOffset(startOffsets.get(i), endOffsets.get(i));
225-
chunks.add(embeddings.get(i).toChunk(offset));
224+
chunks.add(new EmbeddingResults.Chunk(embeddings.get(i), offset));
226225
}
227226
return new ChunkedInferenceEmbedding(chunks);
228227
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.rest.RestStatus;
2828
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
2929
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
30+
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
3031
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3132
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
3233
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
@@ -119,8 +120,8 @@ private static List<ChunkedInference> translateToChunkedResults(DocumentsOnlyInp
119120
results.add(
120121
new ChunkedInferenceEmbedding(
121122
List.of(
122-
new TextEmbeddingFloatResults.Chunk(
123-
textEmbeddingResults.embeddings().get(i).values(),
123+
new EmbeddingResults.Chunk(
124+
textEmbeddingResults.embeddings().get(i),
124125
new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
125126
)
126127
)

0 commit comments

Comments
 (0)