Skip to content

Commit 786c9cb

Browse files
committed
more accurate byte average
1 parent 1140130 commit 786c9cb

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,16 @@ public int hashCode() {
120120

121121
// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
122122
// embeddings should happen in between serializations.
123-
public record Embedding(byte[] values, int numberOfMergedEmbeddings)
123+
public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
124124
implements
125125
Writeable,
126126
ToXContentObject,
127127
EmbeddingResults.Embedding<Chunk, Embedding> {
128+
128129
public static final String EMBEDDING = "embedding";
129130

130131
public Embedding(byte[] values) {
131-
this(values, 1);
132+
this(values, null, 1);
132133
}
133134

134135
public Embedding(StreamInput in) throws IOException {
@@ -201,18 +202,19 @@ public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
201202
return new Chunk(values, text, offset);
202203
}
203204

204-
// This merge function suffers from round-off errors. TODO: maybe do something smarter?
205205
@Override
206206
public Embedding merge(Embedding embedding) {
207-
byte[] mergedValues = new byte[values.length];
207+
byte[] newValues = new byte[values.length];
208+
int[] newSumMergedValues = new int[values.length];
208209
int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
209210
for (int i = 0; i < values.length; i++) {
211+
newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i])
212+
+ (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
210213
// Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
211214
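// closest byte instead of truncating. NOTE(review): this is exact only for non-negative sums; Java integer division truncates toward zero, so a negative sum can land one off (e.g. sum -5 over 3 gives -1, not -2).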
212-
mergedValues[i] = (byte) ((numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i]
213-
+ newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
215+
newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
214216
}
215-
return new Embedding(mergedValues, newNumberOfMergedEmbeddings);
217+
return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
216218
}
217219
}
218220

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.util.List;
1818
import java.util.Map;
1919

20+
import static org.hamcrest.Matchers.equalTo;
2021
import static org.hamcrest.Matchers.is;
2122

2223
public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingByteResults> {
@@ -103,6 +104,16 @@ public void testTransformToCoordinationFormat() {
103104
);
104105
}
105106

107+
public void testEmbeddingMerge() {
108+
TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
109+
TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
110+
TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
111+
TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
112+
assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
113+
mergedEmbedding = mergedEmbedding.merge(embedding3);
114+
assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
115+
}
116+
106117
@Override
107118
protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
108119
return TextEmbeddingByteResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ public void testVeryLongInput_Byte() {
460460
assertThat(embedding.chunks(), hasSize(512));
461461

462462
// The first merged chunk consists of 20 small chunks (so 400 words) and the weight
463-
// is the average of the weights 2 ... 21, with some round-off errors.
463+
// is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
464464
assertThat(embedding.chunks().get(0).matchedText(), startsWith("word0 word1 "));
465465
assertThat(embedding.chunks().get(0).matchedText(), endsWith(" word398 word399"));
466466
assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
@@ -469,7 +469,7 @@ public void testVeryLongInput_Byte() {
469469

470470
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
471471
// is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so
472-
// the average of -1, 0, 1, ... , 17, with some round-off errors.
472+
// the average of -1, 0, 1, ... , 17, so 8.
473473
assertThat(embedding.chunks().get(511).matchedText(), startsWith(" word199620 word199621 "));
474474
assertThat(embedding.chunks().get(511).matchedText(), endsWith(" word199998 word199999"));
475475
assertThat(embedding.chunks().get(511), instanceOf(TextEmbeddingByteResults.Chunk.class));

0 commit comments

Comments
 (0)