Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,13 @@ public void transposeHalfByteLegacy(Blackhole bh) {
bh.consume(packed);
}
}

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void transposeHalfBytePanama(Blackhole bh) {
for (int i = 0; i < numVectors; i++) {
BQSpaceUtils.transposeHalfByte(qVectors[i], packed);
bh.consume(packed);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,22 @@ public static void packAsBinary(int[] vector, byte[] packed) {
}
IMPL.packAsBinary(vector, packed);
}

/**
* The idea here is to organize the query vector bits such that the first bit
* of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,
* third, and fourth bits are in the second, third, and fourth set of dimensions bits,
* respectively. This allows for direct bitwise comparisons with the stored index vectors through
* summing the bitwise results with the relative required bit shifts.
*
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
* @param quantQueryByte the byte array to store the transposed query vector.
*
**/
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
if (quantQueryByte.length * Byte.SIZE < 4 * q.length) {
throw new IllegalArgumentException("packed array is too small: " + quantQueryByte.length * Byte.SIZE + " < " + 4 * q.length);
}
IMPL.transposeHalfByte(q, quantQueryByte);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -353,4 +353,54 @@ public static void packAsBinaryImpl(int[] vector, byte[] packed) {
}
packed[index] = result;
}

@Override
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
transposeHalfByteImpl(q, quantQueryByte);
}

public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) {
int limit = q.length - 7;
int i = 0;
int index = 0;
for (; i < limit; i += 8, index++) {
assert q[i] >= 0 && q[i] <= 15;
assert q[i + 1] >= 0 && q[i + 1] <= 15;
assert q[i + 2] >= 0 && q[i + 2] <= 15;
assert q[i + 3] >= 0 && q[i + 3] <= 15;
assert q[i + 4] >= 0 && q[i + 4] <= 15;
assert q[i + 5] >= 0 && q[i + 5] <= 15;
assert q[i + 6] >= 0 && q[i + 6] <= 15;
assert q[i + 7] >= 0 && q[i + 7] <= 15;
int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i
+ 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1);
int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1)
<< 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1);
int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1)
<< 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1);
int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4
| ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1);
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
if (i == q.length) {
return; // all done
}
int lowerByte = 0;
int lowerMiddleByte = 0;
int upperMiddleByte = 0;
int upperByte = 0;
for (int j = 7; i < q.length; j--, i++) {
lowerByte |= (q[i] & 1) << j;
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
upperByte |= ((q[i] >> 3) & 1) << j;
}
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,6 @@ void soarDistanceBulk(
);

void packAsBinary(int[] vector, byte[] packed);

void transposeHalfByte(int[] q, byte[] quantQueryByte);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.util.Constants;

import static jdk.incubator.vector.VectorOperators.ADD;
import static jdk.incubator.vector.VectorOperators.ASHR;
import static jdk.incubator.vector.VectorOperators.LSHL;
import static jdk.incubator.vector.VectorOperators.MAX;
import static jdk.incubator.vector.VectorOperators.MIN;
Expand Down Expand Up @@ -1021,4 +1022,104 @@ private void packAsBinary128(int[] vector, byte[] packed) {
}
packed[index] = result;
}

@Override
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
// 128 / 32 == 4
if (q.length >= 8 && HAS_FAST_INTEGER_VECTORS) {
if (VECTOR_BITSIZE >= 256) {
transposeHalfByte256(q, quantQueryByte);
return;
} else if (VECTOR_BITSIZE == 128) {
transposeHalfByte128(q, quantQueryByte);
return;
}
}
DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte);
}

private void transposeHalfByte256(int[] q, byte[] quantQueryByte) {
final int limit = INT_SPECIES_256.loopBound(q.length);
int i = 0;
int index = 0;
for (; i < limit; i += INT_SPECIES_256.length(), index++) {
IntVector v = IntVector.fromArray(INT_SPECIES_256, q, i);

int lowerByte = v.and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
int lowerMiddleByte = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
int upperMiddleByte = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
int upperByte = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);

quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;

}
if (i == q.length) {
return; // all done
}
int lowerByte = 0;
int lowerMiddleByte = 0;
int upperMiddleByte = 0;
int upperByte = 0;
for (int j = 7; i < q.length; j--, i++) {
lowerByte |= (q[i] & 1) << j;
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
upperByte |= ((q[i] >> 3) & 1) << j;
}
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}

private void transposeHalfByte128(int[] q, byte[] quantQueryByte) {
final int limit = INT_SPECIES_128.loopBound(q.length) - INT_SPECIES_128.length();
int i = 0;
int index = 0;
for (; i < limit; i += 2 * INT_SPECIES_128.length(), index++) {
IntVector v = IntVector.fromArray(INT_SPECIES_128, q, i);

var lowerByteHigh = v.and(1).lanewise(LSHL, HIGH_SHIFTS_128);
var lowerMiddleByteHigh = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
var upperMiddleByteHigh = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
var upperByteHigh = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, HIGH_SHIFTS_128);

v = IntVector.fromArray(INT_SPECIES_128, q, i + INT_SPECIES_128.length());
var lowerByteLow = v.and(1).lanewise(LSHL, LOW_SHIFTS_128);
var lowerMiddleByteLow = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, LOW_SHIFTS_128);
var upperMiddleByteLow = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, LOW_SHIFTS_128);
var upperByteLow = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, LOW_SHIFTS_128);

int lowerByte = lowerByteHigh.lanewise(OR, lowerByteLow).reduceLanes(OR);
int lowerMiddleByte = lowerMiddleByteHigh.lanewise(OR, lowerMiddleByteLow).reduceLanes(OR);
int upperMiddleByte = upperMiddleByteHigh.lanewise(OR, upperMiddleByteLow).reduceLanes(OR);
int upperByte = upperByteHigh.lanewise(OR, upperByteLow).reduceLanes(OR);

quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;

}
if (i == q.length) {
return; // all done
}
int lowerByte = 0;
int lowerMiddleByte = 0;
int upperMiddleByte = 0;
int upperByte = 0;
for (int j = 7; i < q.length; j--, i++) {
lowerByte |= (q[i] & 1) << j;
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
upperByte |= ((q[i] >> 3) & 1) << j;
}
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,20 @@ public void testPackAsBinary() {
assertArrayEquals(packedLegacy, packed);
}

public void testTransposeHalfByte() {
int dims = randomIntBetween(16, 2048);
int[] toPack = new int[dims];
for (int i = 0; i < dims; i++) {
toPack[i] = randomInt(15);
}
int length = 4 * BQVectorUtils.discretize(dims, 64) / 8;
byte[] packed = new byte[length];
byte[] packedLegacy = new byte[length];
defaultedProvider.getVectorUtilSupport().transposeHalfByte(toPack, packedLegacy);
defOrPanamaProvider.getVectorUtilSupport().transposeHalfByte(toPack, packed);
assertArrayEquals(packedLegacy, packed);
}

private float[] generateRandomVector(int size) {
float[] vector = new float[size];
for (int i = 0; i < size; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package org.elasticsearch.index.codec.vectors;

import org.elasticsearch.simdvec.ESVectorUtil;

/** Utility class for quantization calculations */
public class BQSpaceUtils {

Expand Down Expand Up @@ -117,48 +119,7 @@ public static void transposeHalfByteLegacy(byte[] q, byte[] quantQueryByte) {
* @param quantQueryByte the byte array to store the transposed query vector
* */
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
int limit = q.length - 7;
int i = 0;
int index = 0;
for (; i < limit; i += 8, index++) {
assert q[i] >= 0 && q[i] <= 15;
assert q[i + 1] >= 0 && q[i + 1] <= 15;
assert q[i + 2] >= 0 && q[i + 2] <= 15;
assert q[i + 3] >= 0 && q[i + 3] <= 15;
assert q[i + 4] >= 0 && q[i + 4] <= 15;
assert q[i + 5] >= 0 && q[i + 5] <= 15;
assert q[i + 6] >= 0 && q[i + 6] <= 15;
assert q[i + 7] >= 0 && q[i + 7] <= 15;
int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i
+ 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1);
int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1)
<< 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1);
int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1)
<< 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1);
int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4
| ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1);
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
}
if (i == q.length) {
return; // all done
}
int lowerByte = 0;
int lowerMiddleByte = 0;
int upperMiddleByte = 0;
int upperByte = 0;
for (int j = 7; i < q.length; j--, i++) {
lowerByte |= (q[i] & 1) << j;
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
upperByte |= ((q[i] >> 3) & 1) << j;
}
quantQueryByte[index] = (byte) lowerByte;
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
ESVectorUtil.transposeHalfByte(q, quantQueryByte);
}

/**
Expand Down