Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,58 @@ static float ipFloatBitImpl(float[] q, byte[] d, int start) {
return acc0 + acc1 + acc2 + acc3;
}

/**
* Returns the inner product (aka dot product) between the query vector {@code q}, and the data vector {@code d}.
* <p>
* The query vector should be {@link #B_QUERY}-bit quantized and striped, so that the first {@code n} bits
* of the array are the initial bits of each of the {@code n} vector dimensions; the next {@code n}
* bits are the second bits of each of the {@code n} vector dimensions, and so on
* (this algorithm is only valid for vectors with dimensions a multiple of 8).
* The striping is usually done by {@code BQSpaceUtils.transposeHalfByte}.
* <p>
* The data vector should be single-bit quantized.
*
* <h4>Dot products with bit quantization</h4>
*
* The dot product of any vector with a bit vector is a simple selector - each query vector dimension is multiplied
* by the 0 or 1 in the corresponding data vector dimension; the result is that each dimension value
* is either kept or ignored, with the dimensions that are kept then summed together.
*
* <h4>The algorithm</h4>
*
* The transposition already applied to the query vector ensures there's a 1-to-1 correspondence
* between the data vector bits and query vector bits (see {@code BQSpaceUtils.transposeHalfByte)};
* this means we can use a bitwise {@code &} to keep only the bits of the vector elements we want to sum.
* Essentially, the data vector is used as a selector for each of the striped bits of each vector dimension
* as stored, concatenated together, in {@code q}.
* <p>
* The final dot product result can be obtained by observing that the sum of each stripe of {@code n} bits
* can be computed using the bit count of that stripe. Similar to
* <a href="https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication">long multiplication</a>,
* the result of each stripe of {@code n} bits can be added together by shifting the value {@code s} bits to the left,
* where {@code s} is the stripe number (0-3), then adding to the overall result. Any carry is handled by the add operation.
*
* @param q query vector, {@link #B_QUERY}-bit quantized and striped (see {@code BQSpaceUtils.transposeHalfByte})
* @param d data vector, 1-bit quantized
* @return inner product result
*/
public static long ipByteBinByteImpl(byte[] q, byte[] d) {
long ret = 0;
int size = d.length;
for (int i = 0; i < B_QUERY; i++) {
for (int s = 0; s < B_QUERY; s++) { // for each stripe of B_QUERY-bit quantization in q...
int r = 0;
long subRet = 0;
long stripeRet = 0;
// bitwise & the query and data vectors together, 32-bits at a time, and counting up the bits still set
for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
subRet += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(q, i * size + r) & (int) BitUtil.VH_NATIVE_INT.get(d, r));
stripeRet += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(q, s * size + r) & (int) BitUtil.VH_NATIVE_INT.get(d, r));
}
// handle any tail
// Java operations on bytes automatically extend to int, so we need to mask back down again in case it sign-extends the int
for (; r < d.length; r++) {
subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF);
stripeRet += Integer.bitCount((q[s * size + r] & d[r]) & 0xFF);
}
ret += subRet << i;
// shift the result of the s'th stripe s to the left and add to the result
ret += stripeRet << s;
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@

public interface ESVectorUtilSupport {

/**
* The number of bits in bit-quantized query vectors
*/
short B_QUERY = 4;

/**
* Compute dot product between {@code q} and {@code d}
* @param q query vector, {@link #B_QUERY}-bit quantized and striped (see {@code BQSpaceUtils.transposeHalfByte})
* @param d data vector, 1-bit quantized
*/
long ipByteBinByte(byte[] q, byte[] d);

int ipByteBit(byte[] q, byte[] d);
Expand Down