
Commit aee541b

bytes merge function
1 parent bc6074b commit aee541b

File tree: 2 files changed (+107, -4 lines)

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

Lines changed: 21 additions & 2 deletions
@@ -118,9 +118,19 @@ public int hashCode() {
         return Objects.hash(embeddings);
     }
 
-    public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk, Embedding> {
+    // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+    // embeddings should happen in between serializations.
+    public record Embedding(byte[] values, int numberOfMergedEmbeddings)
+        implements
+            Writeable,
+            ToXContentObject,
+            EmbeddingResults.Embedding<Chunk, Embedding> {
         public static final String EMBEDDING = "embedding";
 
+        public Embedding(byte[] values) {
+            this(values, 1);
+        }
+
         public Embedding(StreamInput in) throws IOException {
             this(in.readByteArray());
         }
@@ -191,9 +201,18 @@ public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
             return new Chunk(values, text, offset);
         }
 
+        // This merge function suffers from round-off errors. TODO: maybe do something smarter?
         @Override
         public Embedding merge(Embedding embedding) {
-            throw new UnsupportedOperationException();
+            byte[] mergedValues = new byte[values.length];
+            int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
+            for (int i = 0; i < values.length; i++) {
+                // Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
+                // closest byte instead of truncating.
+                mergedValues[i] = (byte) ((numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i]
+                    + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
+            }
+            return new Embedding(mergedValues, newNumberOfMergedEmbeddings);
         }
     }
 
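The merge added above keeps a running count per embedding, so repeated pairwise merges approximate the mean of all merged values while rounding to the nearest byte at each step. Below is a minimal standalone sketch of that behaviour; the class and names are placeholders of ours, not part of the commit, and only the averaging arithmetic mirrors the diff.

// Toy illustration, not part of the commit: the same running-average arithmetic as
// TextEmbeddingByteResults.Embedding#merge, reduced to a plain nested record.
public class ByteMergeSketch {

    record Emb(byte[] values, int count) {
        Emb merge(Emb other) {
            int newCount = count + other.count;
            byte[] merged = new byte[values.length];
            for (int i = 0; i < values.length; i++) {
                // Adding newCount / 2 before dividing rounds to the nearest byte instead of truncating.
                merged[i] = (byte) ((count * values[i] + other.count * other.values[i] + newCount / 2) / newCount);
            }
            return new Emb(merged, newCount);
        }
    }

    public static void main(String[] args) {
        // Fold three single-value embeddings {1}, {0}, {0} left to right.
        Emb result = new Emb(new byte[] { 1 }, 1);
        result = result.merge(new Emb(new byte[] { 0 }, 1)); // (1*1 + 1*0 + 1) / 2 = 1
        result = result.merge(new Emb(new byte[] { 0 }, 1)); // (2*1 + 1*0 + 1) / 3 = 1
        // Prints 1, although the exact mean 1/3 would round to 0: the round-off
        // error mentioned in the commit's TODO accumulates across merges.
        System.out.println(result.values()[0] + " (count " + result.count() + ")");
    }
}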
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 86 additions & 2 deletions
@@ -342,8 +342,7 @@ public void testVeryLongInput_Float() {
         }
         assertThat(batches.get(2000).batch().inputs(), hasSize(2));
 
-        // Produce inference results for each request, with just the token
-        // "word" and increasing weights.
+        // Produce inference results for each request, with increasing weights.
         float weight = 0f;
         for (var batch : batches) {
             var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
@@ -403,6 +402,91 @@ public void testVeryLongInput_Float() {
         assertThat(chunk.embedding(), equalTo(new float[] { 10002 / 16384f }));
     }
 
+    public void testVeryLongInput_Byte() {
+        int batchSize = 5;
+        int chunkSize = 20;
+        int numberOfWordsInPassage = (chunkSize * 10000);
+
+        var passageBuilder = new StringBuilder();
+        for (int i = 0; i < numberOfWordsInPassage; i++) {
+            passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
+        }
+
+        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");
+
+        var finalListener = testListener();
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
+            .batchRequestsWithListeners(finalListener);
+
+        // The very long passage is split into 10000 chunks for inference, so
+        // there are 10002 inference requests, resulting in 2001 batches.
+        assertThat(batches, hasSize(2001));
+        for (int i = 0; i < 2000; i++) {
+            assertThat(batches.get(i).batch().inputs(), hasSize(5));
+        }
+        assertThat(batches.get(2000).batch().inputs(), hasSize(2));
+
+        // Produce inference results for each request, with increasing weights.
+        byte weight = 0;
+        for (var batch : batches) {
+            var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
+            for (int i = 0; i < batch.batch().requests().size(); i++) {
+                weight += 1;
+                embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight }));
+            }
+            batch.listener().onResponse(new TextEmbeddingByteResults(embeddings));
+        }
+
+        assertNotNull(finalListener.results);
+        assertThat(finalListener.results, hasSize(3));
+
+        // The first input has the embedding with weight 1.
+        ChunkedInference inference = finalListener.results.get(0);
+        assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
+        ChunkedInferenceEmbedding embedding = (ChunkedInferenceEmbedding) inference;
+        assertThat(embedding.chunks(), hasSize(1));
+        assertThat(embedding.chunks().get(0).matchedText(), equalTo("1st small"));
+        assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        TextEmbeddingByteResults.Chunk chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(0);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 1 }));
+
+        // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
+        // inference. They get the embeddings with weights 2 ... 10001 (wrapping around as bytes).
+        // Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
+        // and therefore 380 or 400 words. For each, the average weight is collected.
+        inference = finalListener.results.get(1);
+        assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
+        embedding = (ChunkedInferenceEmbedding) inference;
+        assertThat(embedding.chunks(), hasSize(512));
+
+        // The first merged chunk consists of 20 small chunks (so 400 words) and the weight
+        // is the average of the weights 2 ... 21, with some round-off errors.
+        assertThat(embedding.chunks().get(0).matchedText(), startsWith("word0 word1 "));
+        assertThat(embedding.chunks().get(0).matchedText(), endsWith(" word398 word399"));
+        assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(0);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 12 }));
+
+        // The last merged chunk consists of 19 small chunks (so 380 words) and the weight
+        // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so
+        // the average of -1, 0, 1, ... , 17, with some round-off errors.
+        assertThat(embedding.chunks().get(511).matchedText(), startsWith(" word199620 word199621 "));
+        assertThat(embedding.chunks().get(511).matchedText(), endsWith(" word199998 word199999"));
+        assertThat(embedding.chunks().get(511), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(511);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 8 }));
+
+        // The last input has the embedding with weight 10002 % 256 = 18.
+        inference = finalListener.results.get(2);
+        assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
+        embedding = (ChunkedInferenceEmbedding) inference;
+        assertThat(embedding.chunks(), hasSize(1));
+        assertThat(embedding.chunks().get(0).matchedText(), equalTo("2nd small"));
+        assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
+        chunk = (TextEmbeddingByteResults.Chunk) embedding.chunks().get(0);
+        assertThat(chunk.embedding(), equalTo(new byte[] { 18 }));
+    }
+
     public void testMergingListener_Float() {
         int batchSize = 5;
         int chunkSize = 20;
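The asserted byte values in the new test follow from signed-byte wrap-around combined with the averaging merge above, and can be sanity-checked in isolation. The small standalone snippet below is ours, not part of the commit; it only reproduces the arithmetic stated in the test's comments.

// Standalone check, not part of the commit: the arithmetic behind the values
// asserted in testVeryLongInput_Byte.
public class ByteWeightCheck {
    public static void main(String[] args) {
        // 10002 requests at batch size 5 need 2001 batches (2000 full ones plus one with 2 inputs).
        System.out.println((10002 + 4) / 5);          // 2001

        // First merged chunk: weights 2 .. 21, exact mean 230 / 20 = 11.5; the test asserts 12.
        int sum = 0;
        for (int w = 2; w <= 21; w++) {
            sum += w;
        }
        System.out.println(sum / 20.0);               // 11.5

        // Last merged chunk: weights 9983 .. 10001 wrap to -1 .. 17 as signed bytes,
        // whose exact mean is 152 / 19 = 8, matching the asserted value.
        System.out.println((byte) 9983);              // -1
        System.out.println((byte) 10001);             // 17
        sum = 0;
        for (int w = -1; w <= 17; w++) {
            sum += w;
        }
        System.out.println(sum / 19);                 // 8

        // The last input is request number 10002, which wraps to 18.
        System.out.println((byte) 10002);             // 18
    }
}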
