Skip to content

Commit a272495

Browse files
committed
Correct bit * byte and bit * float script comparisons (#117404)
I goofed on the bit * byte and bit * float comparisons. Naturally, these should be big-endian and compare the dimensions with the binary ones appropriately. Additionally, I added a test to ensure that this is handled correctly. (cherry picked from commit 374c88a)
1 parent 20a78a1 commit a272495

File tree

7 files changed

+56
-27
lines changed

7 files changed

+56
-27
lines changed

docs/changelog/117404.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117404
2+
summary: Correct bit * byte and bit * float script comparisons
3+
area: Vector Search
4+
type: bug
5+
issues: []

docs/reference/vectors/vector-functions.asciidoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ When using `bit` vectors, not all the vector functions are available. The suppor
336336
this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, which has `dims` number of elements, as a query vector, the `dotProduct` is
337337
the sum of the floating point values using the stored `bit` vector as a mask.
338338

339+
NOTE: When comparing `floats` and `bytes` with `bit` vectors, the `bit` vector is treated as a mask in big-endian order.
340+
For example, if the `bit` vector is `10100001` (i.e. the single byte value `161`) and it is compared
341+
with an array of values `[1, 2, 3, 4, 5, 6, 7, 8]`, the `dotProduct` will be `1 + 3 + 8 = 12`.
342+
339343
Here is an example of using dot-product with bit vectors.
340344

341345
[source,console]

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
5151
/**
5252
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
5353
* This will return the sum of the query vector values using the document vector as a mask.
54+
* When comparing the bits with the bytes, they are done in "big endian" order. For example, if the byte vector
55+
* is [1, 2, 3, 4, 5, 6, 7, 8] and the bit vector is [0b10000000], the inner product will be 1.
5456
* @param q the query vector
5557
* @param d the document vector
5658
* @return the inner product of the two vectors
@@ -63,9 +65,9 @@ public static int ipByteBit(byte[] q, byte[] d) {
6365
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
6466
for (int i = 0; i < d.length; i++) {
6567
byte mask = d[i];
66-
for (int j = 0; j < Byte.SIZE; j++) {
68+
for (int j = Byte.SIZE - 1; j >= 0; j--) {
6769
if ((mask & (1 << j)) != 0) {
68-
result += q[i * Byte.SIZE + j];
70+
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
6971
}
7072
}
7173
}
@@ -75,6 +77,8 @@ public static int ipByteBit(byte[] q, byte[] d) {
7577
/**
7678
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
7779
* This will return the sum of the query vector values using the document vector as a mask.
80+
* When comparing the bits with the floats, they are done in "big endian" order. For example, if the float vector
81+
* is [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] and the bit vector is [0b10000000], the inner product will be 1.0.
7882
* @param q the query vector
7983
* @param d the document vector
8084
* @return the inner product of the two vectors
@@ -86,9 +90,9 @@ public static float ipFloatBit(float[] q, byte[] d) {
8690
float result = 0;
8791
for (int i = 0; i < d.length; i++) {
8892
byte mask = d[i];
89-
for (int j = 0; j < Byte.SIZE; j++) {
93+
for (int j = Byte.SIZE - 1; j >= 0; j--) {
9094
if ((mask & (1 << j)) != 0) {
91-
result += q[i * Byte.SIZE + j];
95+
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
9296
}
9397
}
9498
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
2121
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
2222
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
2323

24+
public void testIpByteBit() {
25+
byte[] q = new byte[16];
26+
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
27+
random().nextBytes(q);
28+
int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
29+
assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
30+
}
31+
32+
public void testIpFloatBit() {
33+
float[] q = new float[16];
34+
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
35+
random().nextFloat();
36+
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
37+
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
38+
}
39+
2440
public void testBitAndCount() {
2541
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
2642
}

modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ setup:
108108
capabilities:
109109
- method: POST
110110
path: /_search
111-
capabilities: [ byte_float_bit_dot_product ]
111+
capabilities: [ byte_float_bit_dot_product_with_bugfix ]
112112
reason: Capability required to run test
113113
- do:
114114
catch: bad_request
@@ -399,7 +399,7 @@ setup:
399399
capabilities:
400400
- method: POST
401401
path: /_search
402-
capabilities: [ byte_float_bit_dot_product ]
402+
capabilities: [ byte_float_bit_dot_product_with_bugfix ]
403403
test_runner_features: [capabilities, close_to]
404404
reason: Capability required to run test
405405
- do:
@@ -419,13 +419,13 @@ setup:
419419
- match: { hits.total: 3 }
420420

421421
- match: {hits.hits.0._id: "2"}
422-
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
422+
- close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}}
423423

424424
- match: {hits.hits.1._id: "3"}
425-
- close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
425+
- close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}}
426426

427427
- match: {hits.hits.2._id: "1"}
428-
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
428+
- close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}}
429429

430430
- do:
431431
headers:
@@ -444,20 +444,20 @@ setup:
444444
- match: { hits.total: 3 }
445445

446446
- match: {hits.hits.0._id: "2"}
447-
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
447+
- close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}}
448448

449449
- match: {hits.hits.1._id: "3"}
450-
- close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
450+
- close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}}
451451

452452
- match: {hits.hits.2._id: "1"}
453-
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
453+
- close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}}
454454
---
455455
"Dot product with byte":
456456
- requires:
457457
capabilities:
458458
- method: POST
459459
path: /_search
460-
capabilities: [ byte_float_bit_dot_product ]
460+
capabilities: [ byte_float_bit_dot_product_with_bugfix ]
461461
test_runner_features: capabilities
462462
reason: Capability required to run test
463463
- do:
@@ -476,14 +476,14 @@ setup:
476476

477477
- match: { hits.total: 3 }
478478

479-
- match: {hits.hits.0._id: "1"}
480-
- match: {hits.hits.0._score: 248}
479+
- match: {hits.hits.0._id: "3"}
480+
- match: {hits.hits.0._score: 415}
481481

482-
- match: {hits.hits.1._id: "2"}
483-
- match: {hits.hits.1._score: 136}
482+
- match: {hits.hits.1._id: "1"}
483+
- match: {hits.hits.1._score: 168}
484484

485-
- match: {hits.hits.2._id: "3"}
486-
- match: {hits.hits.2._score: 20}
485+
- match: {hits.hits.2._id: "2"}
486+
- match: {hits.hits.2._score: 126}
487487

488488
- do:
489489
headers:
@@ -501,11 +501,11 @@ setup:
501501

502502
- match: { hits.total: 3 }
503503

504-
- match: {hits.hits.0._id: "1"}
505-
- match: {hits.hits.0._score: 248}
504+
- match: {hits.hits.0._id: "3"}
505+
- match: {hits.hits.0._score: 415}
506506

507-
- match: {hits.hits.1._id: "2"}
508-
- match: {hits.hits.1._score: 136}
507+
- match: {hits.hits.1._id: "1"}
508+
- match: {hits.hits.1._score: 168}
509509

510-
- match: {hits.hits.2._id: "3"}
511-
- match: {hits.hits.2._score: 20}
510+
- match: {hits.hits.2._id: "2"}
511+
- match: {hits.hits.2._score: 126}

server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ private SearchCapabilities() {}
2727
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
2828
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
2929
/** Support Byte and Float with Bit dot product. */
30-
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";
30+
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
3131
/** Support docvalue_fields parameter for `dense_vector` field. */
3232
private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
3333
/** Support kql query. */

server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ public void testBitVectorClassBindingsDotProduct() throws IOException {
263263
function = new DotProduct(scoreScript, floatQueryVector, fieldName);
264264
assertEquals(
265265
"dotProduct result is not equal to the expected value!",
266-
0.42f + 0f + 1f - 1f - 0.42f,
266+
-1.4f + 0.42f + 0f + 1f - 1f,
267267
function.dotProduct(),
268268
0.001
269269
);

0 commit comments

Comments
 (0)