Skip to content

Commit 9f42da9

Browse files
committed
Add basic implementations of float-byte script comparisons
1 parent 49352fd commit 9f42da9

File tree

8 files changed

+73
-23
lines changed

8 files changed

+73
-23
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ public static float ipFloatBit(float[] q, byte[] d) {
8080
return IMPL.ipFloatBit(q, d);
8181
}
8282

83+
/**
84+
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a byte vector.
85+
* @param q the query vector
86+
* @param d the document vector
87+
* @return the inner product of the two vectors
88+
*/
89+
public static float ipFloatByte(float[] q, byte[] d) {
90+
if (q.length != d.length) {
91+
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
92+
}
93+
return IMPL.ipFloatByte(q, d);
94+
}
95+
8396
/**
8497
* AND bit count computed over signed bytes.
8598
* Copied from Lucene's XOR implementation

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ public float ipFloatBit(float[] q, byte[] d) {
3939
return ipFloatBitImpl(q, d);
4040
}
4141

42+
// Float-query x byte-document inner product; forwards to the static ipFloatByteImpl of this class.
@Override
43+
public float ipFloatByte(float[] q, byte[] d) {
44+
return ipFloatByteImpl(q, d);
45+
}
46+
4247
public static int ipByteBitImpl(byte[] q, byte[] d) {
4348
assert q.length == d.length * Byte.SIZE;
4449
int acc0 = 0;
@@ -101,4 +106,12 @@ public static long ipByteBinByteImpl(byte[] q, byte[] d) {
101106
}
102107
return ret;
103108
}
109+
110+
public static float ipFloatByteImpl(float[] q, byte[] d) {
111+
float ret = 0;
112+
for (int i = 0; i < q.length; i++) {
113+
ret += q[i] * d[i];
114+
}
115+
return ret;
116+
}
104117
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ public interface ESVectorUtilSupport {
1818
int ipByteBit(byte[] q, byte[] d);
1919

2020
float ipFloatBit(float[] q, byte[] d);
21+
22+
/** Inner product of a float query vector {@code q} and a byte document vector {@code d}. */
float ipFloatByte(float[] q, byte[] d);
2123
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public float ipFloatBit(float[] q, byte[] d) {
5858
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
5959
}
6060

61+
// Float-query x byte-document inner product; delegates to the plain-loop implementation
// DefaultESVectorUtilSupport.ipFloatByteImpl rather than using a vectorized kernel.
@Override
62+
public float ipFloatByte(float[] q, byte[] d) {
63+
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
64+
}
65+
6166
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
6267
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
6368

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,28 @@ public void testIpByteBit() {
3232
public void testIpFloatBit() {
3333
float[] q = new float[16];
3434
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
35-
random().nextFloat();
35+
for (int i = 0; i < q.length; i++) {
36+
q[i] = random().nextFloat();
37+
}
3638
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
3739
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
3840
}
3941

42+
public void testIpFloatByte() {
43+
float[] q = new float[16];
44+
byte[] d = new byte[16];
45+
for (int i = 0; i < q.length; i++) {
46+
q[i] = random().nextFloat();
47+
}
48+
random().nextBytes(d);
49+
50+
float expected = 0;
51+
for (int i = 0; i < q.length; i++) {
52+
expected += q[i] * d[i];
53+
}
54+
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
55+
}
56+
4057
public void testBitAndCount() {
4158
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
4259
}

server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.util.BytesRef;
1313
import org.apache.lucene.util.VectorUtil;
1414
import org.elasticsearch.core.SuppressForbidden;
15+
import org.elasticsearch.simdvec.ESVectorUtil;
1516

1617
import java.nio.ByteBuffer;
1718
import java.util.List;
@@ -61,7 +62,7 @@ public int dotProduct(byte[] queryVector) {
6162

6263
@Override
6364
public double dotProduct(float[] queryVector) {
64-
throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
65+
return ESVectorUtil.ipFloatByte(queryVector, vectorValue);
6566
}
6667

6768
@Override
@@ -142,7 +143,11 @@ public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
142143

143144
@Override
144145
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
145-
throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
146+
if (normalizeQueryVector) {
147+
return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
148+
}
149+
150+
return dotProduct(queryVector) / getMagnitude();
146151
}
147152

148153
@Override

server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.apache.lucene.util.VectorUtil;
1313
import org.elasticsearch.core.SuppressForbidden;
14+
import org.elasticsearch.simdvec.ESVectorUtil;
1415

1516
import java.util.List;
1617

@@ -51,12 +52,12 @@ public float getMagnitude() {
5152

5253
@Override
5354
public int dotProduct(byte[] queryVector) {
54-
return VectorUtil.dotProduct(docVector, queryVector);
55+
return VectorUtil.dotProduct(queryVector, docVector);
5556
}
5657

5758
@Override
5859
public double dotProduct(float[] queryVector) {
59-
throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
60+
return ESVectorUtil.ipFloatByte(queryVector, docVector);
6061
}
6162

6263
@Override
@@ -145,7 +146,11 @@ public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
145146

146147
@Override
147148
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
148-
throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
149+
if (normalizeQueryVector) {
150+
return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
151+
}
152+
153+
return dotProduct(queryVector) / getMagnitude();
149154
}
150155

151156
@Override

server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,12 @@ public void testByteVectorClassBindings() throws IOException {
183183
for (int i = 0; i < queryVectorArray.length; i++) {
184184
queryVectorArray[i] = queryVector.get(i).floatValue();
185185
}
186-
UnsupportedOperationException uoe = expectThrows(
187-
UnsupportedOperationException.class,
188-
() -> field.getInternal().cosineSimilarity(queryVectorArray, true)
186+
assertEquals(
187+
"cosineSimilarity result is not equal to the expected value!",
188+
cosineSimilarityExpected,
189+
field.getInternal().cosineSimilarity(queryVectorArray, true),
190+
0.001
189191
);
190-
assertThat(uoe.getMessage(), containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead"));
191192

192193
// Check each function rejects query vectors with the wrong dimension
193194
IllegalArgumentException e = expectThrows(
@@ -342,11 +343,7 @@ public void testByteVsFloatSimilarity() throws IOException {
342343
switch (field.getElementType()) {
343344
case BYTE -> {
344345
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(byteVector));
345-
UnsupportedOperationException e = expectThrows(
346-
UnsupportedOperationException.class,
347-
() -> field.get().dotProduct(floatVector)
348-
);
349-
assertThat(e.getMessage(), containsString("use [int dotProduct(byte[] queryVector)] instead"));
346+
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
350347
}
351348
case FLOAT -> {
352349
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
@@ -423,14 +420,7 @@ public void testByteVsFloatSimilarity() throws IOException {
423420
switch (field.getElementType()) {
424421
case BYTE -> {
425422
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(byteVector), 0.001);
426-
UnsupportedOperationException e = expectThrows(
427-
UnsupportedOperationException.class,
428-
() -> field.get().cosineSimilarity(floatVector)
429-
);
430-
assertThat(
431-
e.getMessage(),
432-
containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead")
433-
);
423+
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
434424
}
435425
case FLOAT -> {
436426
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);

0 commit comments

Comments
 (0)