diff --git a/docs/changelog/117404.yaml b/docs/changelog/117404.yaml new file mode 100644 index 0000000000000..0bab171956ca9 --- /dev/null +++ b/docs/changelog/117404.yaml @@ -0,0 +1,5 @@ +pr: 117404 +summary: Correct bit * byte and bit * float script comparisons +area: Vector Search +type: bug +issues: [] diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 10dca8084e28a..23419e8eb12b1 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -336,6 +336,10 @@ When using `bit` vectors, not all the vector functions are available. The suppor this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is the sum of the floating point values using the stored `bit` vector as a mask. +NOTE: When comparing `floats` and `bytes` with `bit` vectors, the `bit` vector is treated as a mask in big-endian order. +For example, if the `bit` vector is `10100001` (i.e. the single byte value `161`) and it is compared +with the array of values `[1, 2, 3, 4, 5, 6, 7, 8]`, the `dotProduct` will be `1 + 3 + 8 = 12`. + Here is an example of using dot-product with bit vectors. [source,console] diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index de2cb9042610b..2f4743a47a14a 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -51,6 +51,8 @@ public static long ipByteBinByte(byte[] q, byte[] d) { /** * Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector. * This will return the sum of the query vector values using the document vector as a mask. 
+ * The bits are matched against the bytes in "big endian" order, most significant bit first. For example, if the byte vector + * is [1, 2, 3, 4, 5, 6, 7, 8] and the bit vector is [0b10000000], the inner product will be 1. * @param q the query vector * @param d the document vector * @return the inner product of the two vectors @@ -63,9 +65,9 @@ public static int ipByteBit(byte[] q, byte[] d) { // now combine the two vectors, summing the byte dimensions where the bit in d is `1` for (int i = 0; i < d.length; i++) { byte mask = d[i]; - for (int j = 0; j < Byte.SIZE; j++) { + for (int j = Byte.SIZE - 1; j >= 0; j--) { if ((mask & (1 << j)) != 0) { - result += q[i * Byte.SIZE + j]; + result += q[i * Byte.SIZE + Byte.SIZE - 1 - j]; } } } @@ -75,6 +77,8 @@ public static int ipByteBit(byte[] q, byte[] d) { /** * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector. * This will return the sum of the query vector values using the document vector as a mask. + * The bits are matched against the floats in "big endian" order, most significant bit first. For example, if the float vector + * is [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] and the bit vector is [0b10000000], the inner product will be 1.0. 
* @param q the query vector * @param d the document vector * @return the inner product of the two vectors @@ -86,9 +90,9 @@ public static float ipFloatBit(float[] q, byte[] d) { float result = 0; for (int i = 0; i < d.length; i++) { byte mask = d[i]; - for (int j = 0; j < Byte.SIZE; j++) { + for (int j = Byte.SIZE - 1; j >= 0; j--) { if ((mask & (1 << j)) != 0) { - result += q[i * Byte.SIZE + j]; + result += q[i * Byte.SIZE + Byte.SIZE - 1 - j]; } } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index e9e0fd58f7638..368898b934c87 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -21,6 +21,25 @@ public class ESVectorUtilTests extends BaseVectorizationTests { static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider(); static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider(); + public void testIpByteBit() { + byte[] q = new byte[16]; + byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) }; + random().nextBytes(q); + int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; + assertEquals(expected, ESVectorUtil.ipByteBit(q, d)); + } + + public void testIpFloatBit() { + float[] q = new float[16]; + byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) }; + // populate every query dimension; with q left as all zeros the assertion would hold for any bit order + for (int i = 0; i < q.length; i++) { + q[i] = random().nextFloat(); + } + float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; + assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6); + } + public void testBitAndCount() { testBasicBitAndImpl(ESVectorUtil::andBitCountLong); } diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml 
b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml index 2ee38f849e9d4..cdd65ca0eb296 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml @@ -108,7 +108,7 @@ setup: capabilities: - method: POST path: /_search - capabilities: [ byte_float_bit_dot_product ] + capabilities: [ byte_float_bit_dot_product_with_bugfix ] reason: Capability required to run test - do: catch: bad_request @@ -399,7 +399,7 @@ setup: capabilities: - method: POST path: /_search - capabilities: [ byte_float_bit_dot_product ] + capabilities: [ byte_float_bit_dot_product_with_bugfix ] test_runner_features: [capabilities, close_to] reason: Capability required to run test - do: @@ -419,13 +419,13 @@ setup: - match: { hits.total: 3 } - match: {hits.hits.0._id: "2"} - - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}} + - close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}} - match: {hits.hits.1._id: "3"} - - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}} + - close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}} - match: {hits.hits.2._id: "1"} - - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}} + - close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}} - do: headers: @@ -444,20 +444,20 @@ setup: - match: { hits.total: 3 } - match: {hits.hits.0._id: "2"} - - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}} + - close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}} - match: {hits.hits.1._id: "3"} - - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}} + - close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}} - match: {hits.hits.2._id: "1"} - - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}} + - close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}} --- "Dot product 
with byte": - requires: capabilities: - method: POST path: /_search - capabilities: [ byte_float_bit_dot_product ] + capabilities: [ byte_float_bit_dot_product_with_bugfix ] test_runner_features: capabilities reason: Capability required to run test - do: @@ -476,14 +476,14 @@ setup: - match: { hits.total: 3 } - - match: {hits.hits.0._id: "1"} - - match: {hits.hits.0._score: 248} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0._score: 415} - - match: {hits.hits.1._id: "2"} - - match: {hits.hits.1._score: 136} + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 168} - - match: {hits.hits.2._id: "3"} - - match: {hits.hits.2._score: 20} + - match: {hits.hits.2._id: "2"} + - match: {hits.hits.2._score: 126} - do: headers: @@ -501,11 +501,11 @@ setup: - match: { hits.total: 3 } - - match: {hits.hits.0._id: "1"} - - match: {hits.hits.0._score: 248} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0._score: 415} - - match: {hits.hits.1._id: "2"} - - match: {hits.hits.1._score: 136} + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 168} - - match: {hits.hits.2._id: "3"} - - match: {hits.hits.2._score: 20} + - match: {hits.hits.2._id: "2"} + - match: {hits.hits.2._score: 126} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 7c88eb6565bd9..65c0d7e8cc551 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -27,7 +27,7 @@ private SearchCapabilities() {} /** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */ private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source"; /** Support Byte and Float with Bit dot product. 
*/ - private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product"; + private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix"; /** Support docvalue_fields parameter for `dense_vector` field. */ private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields"; /** Support kql query. */ diff --git a/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java b/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java index 2d9caca1ba6a1..cd6c781db8c00 100644 --- a/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java @@ -263,7 +263,7 @@ public void testBitVectorClassBindingsDotProduct() throws IOException { function = new DotProduct(scoreScript, floatQueryVector, fieldName); assertEquals( "dotProduct result is not equal to the expected value!", - 0.42f + 0f + 1f - 1f - 0.42f, + -1.4f + 0.42f + 0f + 1f - 1f, function.dotProduct(), 0.001 );