Skip to content

Commit 616b390

Browse files
[8.x] Add support for bitwise inner-product in painless (#116082) (#116285)
* Add support for bitwise inner-product in painless (#116082) This adds bitwise inner product to painless. The idea here is: - For two bit arrays, which we determine to be a byte array whose dimensions match `dense_vector.dim/8`, we simply return bitwise `&` - For a stored bit array (remember, with `dense_vector.dim/8` bytes), sum up the provided byte or float array using the bit array as a mask. This is effectively supporting asynchronous quantization. A prime example of how this works is: https://github.com/cohere-ai/BinaryVectorDB Basically, you do your initial search against the binary space and then rerank with a differently quantized vector allowing for more information without additional storage space. closes: #111232 * removing unnecessary task adjustment --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent a4d1abb commit 616b390

File tree

13 files changed

+548
-15
lines changed

13 files changed

+548
-15
lines changed

docs/changelog/116082.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 116082
2+
summary: Add support for bitwise inner-product in painless
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

docs/reference/vectors/vector-functions.asciidoc

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ This is the list of available vector functions and vector access methods:
1616
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
1717
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
1818

19-
NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
19+
NOTE: The `cosineSimilarity` function is not supported for `bit` vectors.
2020

2121
NOTE: The recommended way to access dense vectors is through the
2222
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
@@ -332,6 +332,92 @@ When using `bit` vectors, not all the vector functions are available. The suppor
332332
* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
333333
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
334334
* <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance, this is the square root of the `hamming` distance
335+
* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors,
336+
this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is
337+
the sum of the floating point values using the stored `bit` vector as a mask.
335338

336-
Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
339+
Here is an example of using dot-product with bit vectors.
340+
341+
[source,console]
342+
--------------------------------------------------
343+
PUT my-index-bit-vectors
344+
{
345+
"mappings": {
346+
"properties": {
347+
"my_dense_vector": {
348+
"type": "dense_vector",
349+
"index": false,
350+
"element_type": "bit",
351+
"dims": 40 <1>
352+
}
353+
}
354+
}
355+
}
356+
357+
PUT my-index-bit-vectors/_doc/1
358+
{
359+
"my_dense_vector": [8, 5, -15, 1, -7] <2>
360+
}
361+
362+
PUT my-index-bit-vectors/_doc/2
363+
{
364+
"my_dense_vector": [-1, 115, -3, 4, -128]
365+
}
366+
367+
PUT my-index-bit-vectors/_doc/3
368+
{
369+
"my_dense_vector": [2, 18, -5, 0, -124]
370+
}
371+
372+
POST my-index-bit-vectors/_refresh
373+
--------------------------------------------------
374+
// TEST[continued]
375+
<1> The number of dimensions or bits for the `bit` vector.
376+
<2> This vector represents 5 bytes, or `5 * 8 = 40` bits, which equals the configured dimensions
377+
378+
[source,console]
379+
--------------------------------------------------
380+
GET my-index-bit-vectors/_search
381+
{
382+
"query": {
383+
"script_score": {
384+
"query" : {
385+
"match_all": {}
386+
},
387+
"script": {
388+
"source": "dotProduct(params.query_vector, 'my_dense_vector')",
389+
"params": {
390+
"query_vector": [8, 5, -15, 1, -7] <1>
391+
}
392+
}
393+
}
394+
}
395+
}
396+
--------------------------------------------------
397+
// TEST[continued]
398+
<1> This vector is 40 bits, and thus will compute a bitwise `&` operation with the stored vectors.
399+
400+
[source,console]
401+
--------------------------------------------------
402+
GET my-index-bit-vectors/_search
403+
{
404+
"query": {
405+
"script_score": {
406+
"query" : {
407+
"match_all": {}
408+
},
409+
"script": {
410+
"source": "dotProduct(params.query_vector, 'my_dense_vector')",
411+
"params": {
412+
"query_vector": [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67] <1>
413+
}
414+
}
415+
}
416+
}
417+
}
418+
--------------------------------------------------
419+
// TEST[continued]
420+
<1> This vector is 40 individual dimensions, and thus will sum the floating point values using the stored `bit` vector as a mask.
421+
422+
Currently, the `cosineSimilarity` function is not supported for `bit` vectors.
337423

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,36 @@
99

1010
package org.elasticsearch.simdvec;
1111

12+
import org.apache.lucene.util.BitUtil;
13+
import org.apache.lucene.util.Constants;
1214
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
1315
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
1416

17+
import java.lang.invoke.MethodHandle;
18+
import java.lang.invoke.MethodHandles;
19+
import java.lang.invoke.MethodType;
20+
1521
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
1622

1723
public class ESVectorUtil {
1824

25+
private static final MethodHandle BIT_COUNT_MH;
26+
static {
27+
try {
28+
// For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
29+
// On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
30+
// compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
31+
// https://bugs.openjdk.org/browse/JDK-8336000
32+
BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64")
33+
? MethodHandles.lookup()
34+
.findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class))
35+
: MethodHandles.lookup()
36+
.findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class));
37+
} catch (NoSuchMethodException | IllegalAccessException e) {
38+
throw new AssertionError(e);
39+
}
40+
}
41+
1942
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
2043

2144
public static long ipByteBinByte(byte[] q, byte[] d) {
@@ -24,4 +47,103 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
2447
}
2548
return IMPL.ipByteBinByte(q, d);
2649
}
50+
51+
/**
52+
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
53+
* This will return the sum of the query vector values using the document vector as a mask.
54+
* @param q the query vector
55+
* @param d the document vector
56+
* @return the inner product of the two vectors
57+
*/
58+
public static int ipByteBit(byte[] q, byte[] d) {
59+
if (q.length != d.length * Byte.SIZE) {
60+
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
61+
}
62+
int result = 0;
63+
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
64+
for (int i = 0; i < d.length; i++) {
65+
byte mask = d[i];
66+
for (int j = 0; j < Byte.SIZE; j++) {
67+
if ((mask & (1 << j)) != 0) {
68+
result += q[i * Byte.SIZE + j];
69+
}
70+
}
71+
}
72+
return result;
73+
}
74+
75+
/**
76+
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
77+
* This will return the sum of the query vector values using the document vector as a mask.
78+
* @param q the query vector
79+
* @param d the document vector
80+
* @return the inner product of the two vectors
81+
*/
82+
public static float ipFloatBit(float[] q, byte[] d) {
83+
if (q.length != d.length * Byte.SIZE) {
84+
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
85+
}
86+
float result = 0;
87+
for (int i = 0; i < d.length; i++) {
88+
byte mask = d[i];
89+
for (int j = 0; j < Byte.SIZE; j++) {
90+
if ((mask & (1 << j)) != 0) {
91+
result += q[i * Byte.SIZE + j];
92+
}
93+
}
94+
}
95+
return result;
96+
}
97+
98+
/**
99+
* AND bit count computed over signed bytes.
100+
* Copied from Lucene's XOR implementation
101+
* @param a bytes containing a vector
102+
* @param b bytes containing another vector, of the same dimension
103+
* @return the value of the AND bit count of the two vectors
104+
*/
105+
public static int andBitCount(byte[] a, byte[] b) {
106+
if (a.length != b.length) {
107+
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
108+
}
109+
try {
110+
return (int) BIT_COUNT_MH.invokeExact(a, b);
111+
} catch (Throwable e) {
112+
if (e instanceof Error err) {
113+
throw err;
114+
} else if (e instanceof RuntimeException re) {
115+
throw re;
116+
} else {
117+
throw new RuntimeException(e);
118+
}
119+
}
120+
}
121+
122+
/** AND bit count striding over 4 bytes at a time. */
123+
static int andBitCountInt(byte[] a, byte[] b) {
124+
int distance = 0, i = 0;
125+
// limit to number of int values in the array iterating by int byte views
126+
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
127+
distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
128+
}
129+
// tail:
130+
for (; i < a.length; i++) {
131+
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
132+
}
133+
return distance;
134+
}
135+
136+
/** AND bit count striding over 8 bytes at a time**/
137+
static int andBitCountLong(byte[] a, byte[] b) {
138+
int distance = 0, i = 0;
139+
// limit to number of long values in the array iterating by long byte views
140+
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
141+
distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
142+
}
143+
// tail:
144+
for (; i < a.length; i++) {
145+
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
146+
}
147+
return distance;
148+
}
27149
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
2121
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
2222
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
2323

24+
public void testBitAndCount() {
25+
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
26+
}
27+
2428
public void testIpByteBinInvariants() {
2529
int iterations = atLeast(10);
2630
for (int i = 0; i < iterations; i++) {
@@ -41,6 +45,23 @@ interface IpByteBin {
4145
long apply(byte[] q, byte[] d);
4246
}
4347

48+
interface BitOps {
49+
long apply(byte[] q, byte[] d);
50+
}
51+
52+
void testBasicBitAndImpl(BitOps bitAnd) {
53+
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
54+
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
55+
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
56+
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
57+
byte[] a = new byte[31];
58+
byte[] b = new byte[31];
59+
random().nextBytes(a);
60+
random().nextBytes(b);
61+
int expected = scalarBitAnd(a, b);
62+
assertEquals(expected, bitAnd.apply(a, b));
63+
}
64+
4465
void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
4566
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
4667
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
@@ -115,6 +136,14 @@ static int scalarIpByteBin(byte[] q, byte[] d) {
115136
return res;
116137
}
117138

139+
static int scalarBitAnd(byte[] a, byte[] b) {
140+
int res = 0;
141+
for (int i = 0; i < a.length; i++) {
142+
res += Integer.bitCount((a[i] & b[i]) & 0xFF);
143+
}
144+
return res;
145+
}
146+
118147
public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
119148
int res = 0;
120149
for (int j = 0; j < length; j++) {

modules/lang-painless/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ tasks.named("dependencyLicenses").configure {
5353
restResources {
5454
restApi {
5555
include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update',
56-
'scripts_painless_execute', 'put_script', 'delete_script'
56+
'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities'
5757
}
5858
}
5959

0 commit comments

Comments
 (0)