diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java new file mode 100644 index 0000000000000..ce2341f3442ff --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.benchmark.vector; + +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +// first iteration is complete garbage, so make sure we really warmup +@Warmup(iterations = 4, time = 1) +// real iterations. not useful to spend tons of time here, better to fork more +@Measurement(iterations = 5, time = 1) +// engage some noise reduction +@Fork(value = 1) +public class TransposeHalfByteBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Param({ "384", "782", "1024" }) + int dims; + + int length; + + int numVectors = 1000; + + int[][] qVectors; + byte[] packed; + + @Setup + public void setup() throws IOException { + Random random = new Random(123); + + this.length = 4 * BQVectorUtils.discretize(dims, 64) / 8; + this.packed = new byte[length]; + + qVectors = new int[numVectors][dims]; + for (int[] qVector : qVectors) { + for (int i = 0; i < dims; i++) { + qVector[i] = random.nextInt(16); + } + } + } + + @Benchmark + public void transposeHalfByte(Blackhole bh) { + for (int i = 0; i < numVectors; i++) { + BQSpaceUtils.transposeHalfByte(qVectors[i], packed); + bh.consume(packed); + } + } + + @Benchmark + public void transposeHalfByteLegacy(Blackhole bh) { + for (int i = 0; i < numVectors; i++) { + BQSpaceUtils.transposeHalfByteLegacy(qVectors[i], packed); + bh.consume(packed); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java index 99ead8334c21f..06c96e5a2c176 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java @@ -37,6 +37,57 @@ public class BQSpaceUtils { * @param quantQueryByte the byte array to store the transposed query vector */ public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) { + int limit = q.length - 7; + int i = 0; + int index = 0; + for (; i < limit; i += 8, index++) { + assert q[i] >= 0 && q[i] <= 15; + assert q[i + 1] >= 0 && q[i + 1] <= 15; + assert q[i + 2] >= 0 && q[i + 2] <= 15; + assert q[i + 3] >= 0 && q[i + 3] <= 15; + assert q[i + 4] >= 0 && q[i + 4] <= 15; + assert q[i + 5] >= 0 && q[i + 5] <= 15; + assert q[i + 6] >= 0 && q[i + 6] <= 15; + assert q[i + 7] >= 0 && q[i + 7] <= 15; + int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i + + 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1); + int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1) + << 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1); + int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1) + << 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1); + int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4 + | ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1); + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + if (i == q.length) { + return; // all done + } + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; i < q.length; j--, i++) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + } + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + + /** + * Same as {@link #transposeHalfByte(byte[], byte[])} but with more readable but slower code. + * + * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 + * @param quantQueryByte the byte array to store the transposed query vector + * */ + public static void transposeHalfByteLegacy(byte[] q, byte[] quantQueryByte) { for (int i = 0; i < q.length;) { assert q[i] >= 0 && q[i] <= 15; int lowerByte = 0; @@ -66,6 +117,57 @@ public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) { * @param quantQueryByte the byte array to store the transposed query vector * */ public static void transposeHalfByte(int[] q, byte[] quantQueryByte) { + int limit = q.length - 7; + int i = 0; + int index = 0; + for (; i < limit; i += 8, index++) { + assert q[i] >= 0 && q[i] <= 15; + assert q[i + 1] >= 0 && q[i + 1] <= 15; + assert q[i + 2] >= 0 && q[i + 2] <= 15; + assert q[i + 3] >= 0 && q[i + 3] <= 15; + assert q[i + 4] >= 0 && q[i + 4] <= 15; + assert q[i + 5] >= 0 && q[i + 5] <= 15; + assert q[i + 6] >= 0 && q[i + 6] <= 15; + assert q[i + 7] >= 0 && q[i + 7] <= 15; + int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i + + 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1); + int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1) + << 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1); + int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1) + << 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1); + int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4 + | ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1); + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + if (i == q.length) { + return; // all done + } + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; i < q.length; j--, i++) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + } + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + + /** + * Same as {@link #transposeHalfByte(int[], byte[])} but with more readable but slower code. + * + * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 + * @param quantQueryByte the byte array to store the transposed query vector + * */ + public static void transposeHalfByteLegacy(int[] q, byte[] quantQueryByte) { for (int i = 0; i < q.length;) { assert q[i] >= 0 && q[i] <= 15; int lowerByte = 0; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQSpaceUtilsTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQSpaceUtilsTests.java new file mode 100644 index 0000000000000..3e72ff011167e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQSpaceUtilsTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.elasticsearch.test.ESTestCase; + +public class BQSpaceUtilsTests extends ESTestCase { + + public void testIntegerTransposeHalfByte() { + int dims = randomIntBetween(16, 2048); + int[] toPack = new int[dims]; + for (int i = 0; i < dims; i++) { + toPack[i] = randomInt(15); + } + int length = 4 * BQVectorUtils.discretize(dims, 64) / 8; + byte[] packed = new byte[length]; + byte[] packedLegacy = new byte[length]; + BQSpaceUtils.transposeHalfByteLegacy(toPack, packedLegacy); + BQSpaceUtils.transposeHalfByte(toPack, packed); + assertArrayEquals(packedLegacy, packed); + } + + public void testByteTransposeHalfByte() { + int dims = randomIntBetween(16, 2048); + byte[] toPack = new byte[dims]; + for (int i = 0; i < dims; i++) { + toPack[i] = (byte) randomInt(15); + } + int length = 4 * BQVectorUtils.discretize(dims, 64) / 8; + byte[] packed = new byte[length]; + byte[] packedLegacy = new byte[length]; + BQSpaceUtils.transposeHalfByteLegacy(toPack, packedLegacy); + BQSpaceUtils.transposeHalfByte(toPack, packed); + assertArrayEquals(packedLegacy, packed); + } +}