Skip to content

Commit 7295898

Browse files
committed
more accurate byte average
1 parent c384d5f commit 7295898

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,16 @@ public int hashCode() {
120120

121121
// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
122122
// embeddings should happen inbetween serializations.
123-
public record Embedding(byte[] values, int numberOfMergedEmbeddings)
123+
public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
124124
implements
125125
Writeable,
126126
ToXContentObject,
127127
EmbeddingResults.Embedding<Chunk, Embedding> {
128+
128129
public static final String EMBEDDING = "embedding";
129130

130131
public Embedding(byte[] values) {
131-
this(values, 1);
132+
this(values, null, 1);
132133
}
133134

134135
public Embedding(StreamInput in) throws IOException {
@@ -201,18 +202,19 @@ public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
201202
return new Chunk(values, text, offset);
202203
}
203204

204-
// This merge function suffers from round-off errors. TODO: maybe do something smarter?
205205
@Override
206206
public Embedding merge(Embedding embedding) {
207-
byte[] mergedValues = new byte[values.length];
207+
byte[] newValues = new byte[values.length];
208+
int[] newSumMergedValues = new int[values.length];
208209
int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
209210
for (int i = 0; i < values.length; i++) {
211+
newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i])
212+
+ (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
210213
// Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
211214
// closest byte instead of truncating.
212-
mergedValues[i] = (byte) ((numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i]
213-
+ newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
215+
newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
214216
}
215-
return new Embedding(mergedValues, newNumberOfMergedEmbeddings);
217+
return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
216218
}
217219
}
218220

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.util.List;
1818
import java.util.Map;
1919

20+
import static org.hamcrest.Matchers.equalTo;
2021
import static org.hamcrest.Matchers.is;
2122

2223
public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingByteResults> {
@@ -104,16 +105,24 @@ public void testTransformToCoordinationFormat() {
104105
}
105106

106107
public void testGetFirstEmbeddingSize() {
107-
var firstEmbeddingSize = new TextEmbeddingByteResults(
108-
List.of(
109-
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
110-
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
111-
)
112-
).getFirstEmbeddingSize();
108+
var firstEmbeddingSize = new TextEmbeddingByteResults(List.of(
109+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
110+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
111+
)).getFirstEmbeddingSize();
113112

114113
assertThat(firstEmbeddingSize, is(2));
115114
}
116115

116+
public void testEmbeddingMerge() {
117+
TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
118+
TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
119+
TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
120+
TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
121+
assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
122+
mergedEmbedding = mergedEmbedding.merge(embedding3);
123+
assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
124+
}
125+
117126
@Override
118127
protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
119128
return TextEmbeddingByteResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ public void testVeryLongInput_Byte() {
461461
assertThat(embedding.chunks(), hasSize(512));
462462

463463
// The first merged chunk consists of 20 small chunks (so 400 words) and the weight
464-
// is the average of the weights 2 ... 21, with some round-off errors.
464+
// is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
465465
assertThat(embedding.chunks().get(0).matchedText(), startsWith("word0 word1 "));
466466
assertThat(embedding.chunks().get(0).matchedText(), endsWith(" word398 word399"));
467467
assertThat(embedding.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
@@ -470,7 +470,7 @@ public void testVeryLongInput_Byte() {
470470

471471
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
472472
// is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so
473-
// the average of -1, 0, 1, ... , 17, with some round-off errors.
473+
// the average of -1, 0, 1, ... , 17, so 8.
474474
assertThat(embedding.chunks().get(511).matchedText(), startsWith(" word199620 word199621 "));
475475
assertThat(embedding.chunks().get(511).matchedText(), endsWith(" word199998 word199999"));
476476
assertThat(embedding.chunks().get(511), instanceOf(TextEmbeddingByteResults.Chunk.class));

0 commit comments

Comments
 (0)