From d1b198bf2650d7c302e406deb6f9de2a70756ef3 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Thu, 14 Aug 2025 15:59:02 +0100 Subject: [PATCH] Vectorize BQSpaceUtils#transposeHalfByte --- .../vector/TransposeHalfByteBenchmark.java | 9 ++ .../elasticsearch/simdvec/ESVectorUtil.java | 18 ++++ .../DefaultESVectorUtilSupport.java | 50 +++++++++ .../vectorization/ESVectorUtilSupport.java | 2 + .../PanamaESVectorUtilSupport.java | 101 ++++++++++++++++++ .../simdvec/ESVectorUtilTests.java | 14 +++ .../index/codec/vectors/BQSpaceUtils.java | 45 +------- 7 files changed, 197 insertions(+), 42 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java index ce2341f3442ff..b612e35d37292 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java @@ -83,4 +83,13 @@ public void transposeHalfByteLegacy(Blackhole bh) { bh.consume(packed); } } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void transposeHalfBytePanama(Blackhole bh) { + for (int i = 0; i < numVectors; i++) { + BQSpaceUtils.transposeHalfByte(qVectors[i], packed); + bh.consume(packed); + } + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 81fa6a959574f..c083f1c92a4fd 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -381,4 +381,22 @@ public static void packAsBinary(int[] vector, byte[] packed) { } IMPL.packAsBinary(vector, packed); } + + /** + * The idea here is to organize the query vector bits such that the first bit + * of every dimension is in the 
first set of dimensions bits, or (dimensions/8) bytes. The second, + * third, and fourth bits are in the second, third, and fourth set of dimensions bits, + * respectively. This allows for direct bitwise comparisons with the stored index vectors through + * summing the bitwise results with the relative required bit shifts. + * + * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 + * @param quantQueryByte the byte array to store the transposed query vector. + * + **/ + public static void transposeHalfByte(int[] q, byte[] quantQueryByte) { + if (quantQueryByte.length * Byte.SIZE < 4 * q.length) { + throw new IllegalArgumentException("packed array is too small: " + quantQueryByte.length * Byte.SIZE + " < " + 4 * q.length); + } + IMPL.transposeHalfByte(q, quantQueryByte); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index f9c52a544db9c..c78970a0c8794 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -353,4 +353,54 @@ public static void packAsBinaryImpl(int[] vector, byte[] packed) { } packed[index] = result; } + + @Override + public void transposeHalfByte(int[] q, byte[] quantQueryByte) { + transposeHalfByteImpl(q, quantQueryByte); + } + + public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) { + int limit = q.length - 7; + int i = 0; + int index = 0; + for (; i < limit; i += 8, index++) { + assert q[i] >= 0 && q[i] <= 15; + assert q[i + 1] >= 0 && q[i + 1] <= 15; + assert q[i + 2] >= 0 && q[i + 2] <= 15; + assert q[i + 3] >= 0 && q[i + 3] <= 15; + assert q[i + 4] >= 0 && q[i + 4] <= 15; + assert q[i + 5] >= 0 && q[i + 5] <= 15; + assert
q[i + 6] >= 0 && q[i + 6] <= 15; + assert q[i + 7] >= 0 && q[i + 7] <= 15; + int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i + + 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1); + int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1) + << 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1); + int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1) + << 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1); + int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4 + | ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1); + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + if (i == q.length) { + return; // all done + } + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; i < q.length; j--, i++) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + } + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } } diff --git 
a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java index d75b9cf6cfd25..08c256051661e 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -65,4 +65,6 @@ void soarDistanceBulk( ); void packAsBinary(int[] vector, byte[] packed); + + void transposeHalfByte(int[] q, byte[] quantQueryByte); } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 33c2eca4c6152..62637a621cd0b 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.Constants; import static jdk.incubator.vector.VectorOperators.ADD; +import static jdk.incubator.vector.VectorOperators.ASHR; import static jdk.incubator.vector.VectorOperators.LSHL; import static jdk.incubator.vector.VectorOperators.MAX; import static jdk.incubator.vector.VectorOperators.MIN; @@ -1021,4 +1022,104 @@ private void packAsBinary128(int[] vector, byte[] packed) { } packed[index] = result; } + + @Override + public void transposeHalfByte(int[] q, byte[] quantQueryByte) { + // 128 / 32 == 4 + if (q.length >= 8 && HAS_FAST_INTEGER_VECTORS) { + if (VECTOR_BITSIZE >= 256) { + transposeHalfByte256(q, quantQueryByte); + return; + } else if (VECTOR_BITSIZE == 128) { + transposeHalfByte128(q, quantQueryByte); + return; + } + } + DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte); + } + + private void 
transposeHalfByte256(int[] q, byte[] quantQueryByte) { + final int limit = INT_SPECIES_256.loopBound(q.length); + int i = 0; + int index = 0; + for (; i < limit; i += INT_SPECIES_256.length(), index++) { + IntVector v = IntVector.fromArray(INT_SPECIES_256, q, i); + + int lowerByte = v.and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR); + int lowerMiddleByte = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR); + int upperMiddleByte = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR); + int upperByte = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR); + + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + + } + if (i == q.length) { + return; // all done + } + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; i < q.length; j--, i++) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + } + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + + private void transposeHalfByte128(int[] q, byte[] quantQueryByte) { + final int limit = INT_SPECIES_128.loopBound(q.length) - INT_SPECIES_128.length(); + int i = 0; + int index = 0; + for (; i < limit; i += 2 * INT_SPECIES_128.length(), index++) { + IntVector v = IntVector.fromArray(INT_SPECIES_128, q, i); + + var lowerByteHigh = v.and(1).lanewise(LSHL, HIGH_SHIFTS_128); + var 
lowerMiddleByteHigh = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, HIGH_SHIFTS_128); + var upperMiddleByteHigh = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, HIGH_SHIFTS_128); + var upperByteHigh = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, HIGH_SHIFTS_128); + + v = IntVector.fromArray(INT_SPECIES_128, q, i + INT_SPECIES_128.length()); + var lowerByteLow = v.and(1).lanewise(LSHL, LOW_SHIFTS_128); + var lowerMiddleByteLow = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, LOW_SHIFTS_128); + var upperMiddleByteLow = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, LOW_SHIFTS_128); + var upperByteLow = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, LOW_SHIFTS_128); + + int lowerByte = lowerByteHigh.lanewise(OR, lowerByteLow).reduceLanes(OR); + int lowerMiddleByte = lowerMiddleByteHigh.lanewise(OR, lowerMiddleByteLow).reduceLanes(OR); + int upperMiddleByte = upperMiddleByteHigh.lanewise(OR, upperMiddleByteLow).reduceLanes(OR); + int upperByte = upperByteHigh.lanewise(OR, upperByteLow).reduceLanes(OR); + + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + + } + if (i == q.length) { + return; // all done + } + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; i < q.length; j--, i++) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + } + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java 
b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index b0aa4a9a45afe..24aff1107d7e7 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -370,6 +370,20 @@ public void testPackAsBinary() { assertArrayEquals(packedLegacy, packed); } + public void testTransposeHalfByte() { + int dims = randomIntBetween(16, 2048); + int[] toPack = new int[dims]; + for (int i = 0; i < dims; i++) { + toPack[i] = randomInt(15); + } + int length = 4 * BQVectorUtils.discretize(dims, 64) / 8; + byte[] packed = new byte[length]; + byte[] packedLegacy = new byte[length]; + defaultedProvider.getVectorUtilSupport().transposeHalfByte(toPack, packedLegacy); + defOrPanamaProvider.getVectorUtilSupport().transposeHalfByte(toPack, packed); + assertArrayEquals(packedLegacy, packed); + } + private float[] generateRandomVector(int size) { float[] vector = new float[size]; for (int i = 0; i < size; ++i) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java index 06c96e5a2c176..bb26357cb6990 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java @@ -19,6 +19,8 @@ */ package org.elasticsearch.index.codec.vectors; +import org.elasticsearch.simdvec.ESVectorUtil; + /** Utility class for quantization calculations */ public class BQSpaceUtils { @@ -117,48 +119,7 @@ public static void transposeHalfByteLegacy(byte[] q, byte[] quantQueryByte) { * @param quantQueryByte the byte array to store the transposed query vector * */ public static void transposeHalfByte(int[] q, byte[] quantQueryByte) { - int limit = q.length - 7; - int i = 0; - int index = 0; - for (; i < limit; i += 8, index++) { - assert q[i] >= 0 && q[i] <= 15; - assert q[i + 
1] >= 0 && q[i + 1] <= 15; - assert q[i + 2] >= 0 && q[i + 2] <= 15; - assert q[i + 3] >= 0 && q[i + 3] <= 15; - assert q[i + 4] >= 0 && q[i + 4] <= 15; - assert q[i + 5] >= 0 && q[i + 5] <= 15; - assert q[i + 6] >= 0 && q[i + 6] <= 15; - assert q[i + 7] >= 0 && q[i + 7] <= 15; - int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i - + 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1); - int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1) - << 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1); - int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1) - << 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1); - int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4 - | ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1); - quantQueryByte[index] = (byte) lowerByte; - quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; - quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; - quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; - } - if (i == q.length) { - return; // all done - } - int lowerByte = 0; - int lowerMiddleByte = 0; - int upperMiddleByte = 0; - int upperByte = 0; - for (int j = 7; i < q.length; j--, i++) { - lowerByte |= (q[i] & 1) << j; - lowerMiddleByte |= ((q[i] >> 1) & 1) << j; - upperMiddleByte |= ((q[i] >> 2) & 1) << j; - upperByte |= ((q[i] >> 3) & 1) << j; - } - quantQueryByte[index] = (byte) lowerByte; - quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; - quantQueryByte[index + 
quantQueryByte.length / 2] = (byte) upperMiddleByte; - quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + ESVectorUtil.transposeHalfByte(q, quantQueryByte); } /**