Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/116082.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 116082
summary: Add support for bitwise inner-product in painless
area: Vector Search
type: enhancement
issues: []
5 changes: 4 additions & 1 deletion docs/reference/vectors/vector-functions.asciidoc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 19 we also say that dot_product is not supported for bit vectors.

Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ When using `bit` vectors, not all the vector functions are available. The suppor
* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> – calculates L^2^ distance, this is the square root of the `hamming` distance
* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be we can add that queryVector can be byte[] (of the same dims as docs or dims *8), or also can be a string, and can be of float[]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

this is the sum of the bitwise AND of the two vectors. If providing a `float[]` or `byte[]` query vector that has one
element per bit of the stored vector, the `dotProduct` is the sum of the query vector values using the stored `bit` vector as a mask.

Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
Currently, the `cosineSimilarity` function is not supported for `bit` vectors.

101 changes: 101 additions & 0 deletions libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@

package org.elasticsearch.simdvec;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;

public class ESVectorUtil {

/**
* For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
* On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
* compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
* https://bugs.openjdk.org/browse/JDK-8336000
*
* <p>The flag is {@code true} when {@code Constants.OS_ARCH} equals {@code "aarch64"}. In this class it is
* also consulted by {@code andBitCount} to choose between the int-strided and long-strided implementation.
*/
static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64");

// Vector utility implementation, looked up once from the vectorization provider when the class is initialized.
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();

public static long ipByteBinByte(byte[] q, byte[] d) {
Expand All @@ -24,4 +34,95 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
}
return IMPL.ipByteBinByte(q, d);
}

/**
 * Inner product of a byte query vector with a bit-packed document vector.
 * Every bit of {@code d} masks one entry of {@code q}: bit {@code j} of {@code d[i]}, counted from the
 * least-significant bit, selects {@code q[i * Byte.SIZE + j]}. The selected entries are summed.
 * NOTE(review): confirm that this least-significant-bit-first order matches how bit vectors are packed when indexed.
 * @param q the query vector, holding one byte per bit of {@code d}
 * @param d the document vector, holding {@code Byte.SIZE} dimensions per byte
 * @return the sum of the entries of {@code q} whose matching bit in {@code d} is set
 * @throws IllegalArgumentException if {@code q.length} differs from {@code d.length * Byte.SIZE}
 */
public static int ipByteBit(byte[] q, byte[] d) {
    if (q.length != d.length * Byte.SIZE) {
        throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
    }
    int sum = 0;
    // walk the query dimensions in order and keep those whose bit is set in the document vector
    for (int dim = 0; dim < q.length; dim++) {
        final int packed = d[dim / Byte.SIZE];
        final int bit = dim % Byte.SIZE;
        if (((packed >> bit) & 1) == 1) {
            sum += q[dim];
        }
    }
    return sum;
}

/**
 * Inner product of a float query vector with a bit-packed document vector.
 * Every bit of {@code d} masks one entry of {@code q}: bit {@code j} of {@code d[i]}, counted from the
 * least-significant bit, selects {@code q[i * Byte.SIZE + j]}. The selected entries are added up in
 * increasing index order.
 * NOTE(review): confirm that this least-significant-bit-first order matches how bit vectors are packed when indexed.
 * @param q the query vector, holding one float per bit of {@code d}
 * @param d the document vector, holding {@code Byte.SIZE} dimensions per byte
 * @return the sum of the entries of {@code q} whose matching bit in {@code d} is set
 * @throws IllegalArgumentException if {@code q.length} differs from {@code d.length * Byte.SIZE}
 */
public static float ipFloatBit(float[] q, byte[] d) {
    if (q.length != d.length * Byte.SIZE) {
        throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
    }
    float sum = 0;
    // walk the query dimensions in order and keep those whose bit is set in the document vector
    for (int dim = 0; dim < q.length; dim++) {
        final int packed = d[dim / Byte.SIZE];
        final int bit = dim % Byte.SIZE;
        if (((packed >> bit) & 1) == 1) {
            sum += q[dim];
        }
    }
    return sum;
}

/**
 * Counts the bits that are set in both byte-packed vectors, i.e. the bit count of their bitwise AND.
 * The work is delegated to the 4-byte or the 8-byte strided variant, depending on
 * {@code XOR_BIT_COUNT_STRIDE_AS_INT}.
 *
 * @param a bytes containing a vector
 * @param b bytes containing another vector, of the same length as {@code a}
 * @return the AND bit count of the two vectors
 * @throws IllegalArgumentException if the two arrays differ in length
 */
public static int andBitCount(byte[] a, byte[] b) {
    if (a.length == b.length) {
        return XOR_BIT_COUNT_STRIDE_AS_INT ? andBitCountInt(a, b) : andBitCountLong(a, b);
    }
    throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}

/**
 * AND bit count that consumes the arrays one 32-bit word at a time, then finishes byte by byte.
 * The loop bounds are taken from {@code a}; the public entry point checks that both arrays have equal length.
 */
static int andBitCountInt(byte[] a, byte[] b) {
    final int length = a.length;
    // largest multiple of Integer.BYTES that does not exceed the array length
    final int wordAlignedEnd = length & -Integer.BYTES;
    int count = 0;
    int pos = 0;
    while (pos < wordAlignedEnd) {
        final int left = (int) BitUtil.VH_NATIVE_INT.get(a, pos);
        final int right = (int) BitUtil.VH_NATIVE_INT.get(b, pos);
        count += Integer.bitCount(left & right);
        pos += Integer.BYTES;
    }
    // remaining bytes that do not fill a whole word; 0xFF drops the sign extension of the promoted bytes
    while (pos < length) {
        count += Integer.bitCount(a[pos] & b[pos] & 0xFF);
        pos++;
    }
    return count;
}

/**
* AND bit count striding over 8 bytes at a time.
* Reads both arrays as 64-bit words while whole words remain, then handles the leftover bytes one by one.
* The loop bounds are taken from {@code a}; the public entry point checks that both arrays have equal length.
*/
static int andBitCountLong(byte[] a, byte[] b) {
int distance = 0, i = 0;
// upperBound is a.length rounded down to a multiple of Long.BYTES
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
}
// tail: bytes that do not fill a whole 64-bit word; 0xFF drops the sign extension of the promoted bytes
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the tail be done with a single Long.bitCount call, if using a mask based on the number of remaining bytes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly? But I didn't want to bother with over optimizing. Especially since these methods are effectively copy-pastes of what exists in Lucene for xor (just changing to &).

}
return distance;
}
}
2 changes: 1 addition & 1 deletion modules/lang-painless/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ tasks.named("dependencyLicenses").configure {
restResources {
restApi {
include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update',
'scripts_painless_execute', 'put_script', 'delete_script'
'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities'
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,6 @@ setup:
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

---
"Dot Product is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"

---
"Cosine Similarity is not supported":
- do:
Expand Down Expand Up @@ -388,3 +358,119 @@ setup:

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"Dot product with float":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
test_runner_features: [capabilities, close_to]
reason: Capability required to run test
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 27.23, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}

- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'indexed_vector')"
params:
query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 27.23, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
---
"Dot product with byte":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
test_runner_features: capabilities
reason: Capability required to run test
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}

- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'indexed_vector')"
params:
query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ private SearchCapabilities() {}
private static final String RANGE_REGEX_INTERVAL_QUERY_CAPABILITY = "range_regexp_interval_queries";
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";

public static final Set<String> CAPABILITIES = Set.of(
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY
);
}
Loading