diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 3ea1326b4608..3bf3d34f4826 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -150,6 +150,8 @@ Optimizations * GITHUB#15160: Increased the size used for blocks of postings from 128 to 256. This gives a noticeable speedup to many queries. (Adrien Grand) +* GITHUB#14863: Perform scoring for 4, 7, 8 bit quantized vectors off-heap. (Kaival Parikh) + Bug Fixes --------------------- * GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java index a8eb1b945cee..4c8253fdab9f 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java @@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) { private byte[] bytesA; private byte[] bytesB; private byte[] halfBytesA; + private byte[] halfBytesAPacked; private byte[] halfBytesB; private byte[] halfBytesBPacked; private float[] floatsA; private float[] floatsB; - private int expectedhalfByteDotProduct; + private int expectedHalfByteDotProduct; + private int expectedHalfByteSquareDistance; @Param({"1", "128", "207", "256", "300", "512", "702", "1024"}) int size; @@ -74,16 +76,23 @@ public void init() { random.nextBytes(bytesB); // random half byte arrays for binary methods // this means that all values must be between 0 and 15 - expectedhalfByteDotProduct = 0; + expectedHalfByteDotProduct = 0; + expectedHalfByteSquareDistance = 0; halfBytesA = new byte[size]; halfBytesB = new byte[size]; for (int i = 0; i < size; ++i) { halfBytesA[i] = (byte) random.nextInt(16); halfBytesB[i] = (byte) random.nextInt(16); - expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i]; + expectedHalfByteDotProduct += 
halfBytesA[i] * halfBytesB[i]; + + int diff = halfBytesA[i] - halfBytesB[i]; + expectedHalfByteSquareDistance += diff * diff; } // pack the half byte arrays if (size % 2 == 0) { + halfBytesAPacked = new byte[(size + 1) >> 1]; + compressBytes(halfBytesA, halfBytesAPacked); + halfBytesBPacked = new byte[(size + 1) >> 1]; compressBytes(halfBytesB, halfBytesBPacked); } @@ -108,6 +117,74 @@ public float binaryCosineVector() { return VectorUtil.cosine(bytesA, bytesB); } + @Benchmark + public int binarySquareScalar() { + return VectorUtil.squareDistance(bytesA, bytesB); + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binarySquareVector() { + return VectorUtil.squareDistance(bytesA, bytesB); + } + + @Benchmark + public int binaryHalfByteSquareScalar() { + int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteSquareVector() { + int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + public int binaryHalfByteSquareSinglePackedScalar() { + int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteSquareSinglePackedVector() { + int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + 
expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + public int binaryHalfByteSquareBothPackedScalar() { + int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteSquareBothPackedVector() { + int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteSquareDistance) { + throw new RuntimeException("Expected " + expectedHalfByteSquareDistance + " but got " + v); + } + return v; + } + @Benchmark public int binaryDotProductScalar() { return VectorUtil.dotProduct(bytesA, bytesB); @@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() { } @Benchmark - public int binarySquareScalar() { - return VectorUtil.squareDistance(bytesA, bytesB); + public int binaryHalfByteDotProductScalar() { + int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) - public int binarySquareVector() { - return VectorUtil.squareDistance(bytesA, bytesB); + public int binaryHalfByteDotProductVector() { + int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark @@ -153,37 +238,39 @@ public int binarySquareUint8Vector() { } @Benchmark - public int binaryHalfByteScalar() { - return VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + public int binaryHalfByteDotProductSinglePackedScalar() { + int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, 
halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) - public int binaryHalfByteVector() { - return VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + public int binaryHalfByteDotProductSinglePackedVector() { + int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); + } + return v; } @Benchmark - public int binaryHalfByteScalarPacked() { - if (size % 2 != 0) { - throw new RuntimeException("Size must be even for this benchmark"); - } - int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked); - if (v != expectedhalfByteDotProduct) { - throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v); + public int binaryHalfByteDotProductBothPackedScalar() { + int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); } return v; } @Benchmark @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) - public int binaryHalfByteVectorPacked() { - if (size % 2 != 0) { - throw new RuntimeException("Size must be even for this benchmark"); - } - int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked); - if (v != expectedhalfByteDotProduct) { - throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v); + public int binaryHalfByteDotProductBothPackedVector() { + int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked); + if (v != expectedHalfByteDotProduct) { + throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v); } return v; } diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 808d7b3cc882..123c18e00c08 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {} public static FlatVectorsScorer getLucene99FlatVectorsScorer() { return IMPL.getLucene99FlatVectorsScorer(); } + + public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return IMPL.getLucene99ScalarQuantizedVectorsScorer(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index 117521ddcc2a..80afaf5c685a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -23,6 +23,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.FloatToFloatFunction; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -245,7 +246,7 @@ public float score(int vectorOrdinal) throws IOException { values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES)); values.getSlice().readBytes(compressedVector, 0, compressedVector.length); float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); - int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector); + int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, 
compressedVector); // For the current implementation of scalar quantization, all dotproducts should // be >= 0; assert dotProduct >= 0; @@ -301,11 +302,6 @@ public void setScoringOrdinal(int node) throws IOException { } } - @FunctionalInterface - private interface FloatToFloatFunction { - float apply(float f); - } - private static final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 0f339ecbe0a8..76c73980aef8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -18,10 +18,10 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; - final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; + final FlatVectorsScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -115,8 +115,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = - new 
Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer(); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index 89c813a4b93b..7f08c673a7f1 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -164,24 +164,35 @@ public int uint8DotProduct(byte[] a, byte[] b) { } @Override - public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { - assert (apacked && bpacked) == false; - if (apacked || bpacked) { - byte[] packed = apacked ? a : b; - byte[] unpacked = apacked ? b : a; - int total = 0; - for (int i = 0; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; - total += (packedByte & 0x0F) * unpacked2; - total += ((packedByte & 0xFF) >> 4) * unpacked1; - } - return total; - } + public int int4DotProduct(byte[] a, byte[] b) { return dotProduct(a, b); } + @Override + public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + total += (packedByte & 0x0F) * unpacked2; + total += ((packedByte & 0xFF) >> 4) * unpacked1; + } + return total; + } + + @Override + public int int4DotProductBothPacked(byte[] a, byte[] b) { + int total = 0; + for (int i = 0; i < a.length; i++) { + byte aByte = a[i]; + byte bByte = b[i]; + total += (aByte & 0x0F) * (bByte & 0x0F); + total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 
4); + } + return total; + } + @Override public float cosine(byte[] a, byte[] b) { // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. @@ -210,6 +221,42 @@ public int squareDistance(byte[] a, byte[] b) { return squareSum; } + @Override + public int int4SquareDistance(byte[] a, byte[] b) { + return squareDistance(a, b); + } + + @Override + public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + + int diff1 = (packedByte & 0x0F) - unpacked2; + int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1; + + total += diff1 * diff1 + diff2 * diff2; + } + return total; + } + + @Override + public int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + int total = 0; + for (int i = 0; i < a.length; i++) { + byte aByte = a[i]; + byte bByte = b[i]; + + int diff1 = (aByte & 0x0F) - (bByte & 0x0F); + int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4); + + total += diff1 * diff1 + diff2 * diff2; + } + return total; + } + @Override public int uint8SquareDistance(byte[] a, byte[] b) { // Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16. 
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index c5e9301e9bc4..21977fa3dc77 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.store.IndexInput; /** Default provider returning scalar implementations. */ @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return DefaultFlatVectorScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) { return new PostingDecodingUtil(input); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index 7190f983b4ce..7242a2501a19 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -36,18 +36,40 @@ public interface VectorUtilSupport { /** Returns the dot product computed over signed bytes. */ int dotProduct(byte[] a, byte[] b); + /** Returns the dot product computed over unsigned half-bytes, both uncompressed. */ + int int4DotProduct(byte[] a, byte[] b); + + /** Returns the dot product computed over unsigned half-bytes, one compressed. 
*/ + int int4DotProductSinglePacked(byte[] unpacked, byte[] packed); + + /** Returns the dot product computed over unsigned half-bytes, both compressed. */ + int int4DotProductBothPacked(byte[] a, byte[] b); + /** Returns the dot product computed as though the bytes were unsigned. */ int uint8DotProduct(byte[] a, byte[] b); - /** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */ - int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked); - /** Returns the cosine similarity between the two byte vectors. */ float cosine(byte[] a, byte[] b); /** Returns the sum of squared differences of the two byte vectors. */ int squareDistance(byte[] a, byte[] b); + /** + * Returns the sum of squared differences between two unsigned half-byte vectors, both + * uncompressed. + */ + int int4SquareDistance(byte[] a, byte[] b); + + /** + * Returns the sum of squared differences between two unsigned half-byte vectors, one compressed. + */ + int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed); + + /** + * Returns the sum of squared differences between two unsigned half-byte vectors, both compressed. + */ + int int4SquareDistanceBothPacked(byte[] a, byte[] b); + /** Returns the sum of squared differences of the two unsigned byte vectors. */ int uint8SquareDistance(byte[] a, byte[] b); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index 24864318af5a..cf9c56c59774 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -109,6 +109,9 @@ public static VectorizationProvider getInstance() { /** Returns a FlatVectorsScorer that supports the Lucene99 format. 
*/ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + /** Returns a FlatVectorsScorer that supports the Lucene99 scalar quantized format. */ + public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer(); + /** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */ public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException; diff --git a/lucene/core/src/java/org/apache/lucene/util/FloatToFloatFunction.java b/lucene/core/src/java/org/apache/lucene/util/FloatToFloatFunction.java new file mode 100644 index 000000000000..9068a5438361 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/FloatToFloatFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +/** + * Simple interface to map one float to another (useful in scaling scores). 
+ * + * @lucene.internal + */ +@FunctionalInterface +public interface FloatToFloatFunction { + float apply(float f); +} diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 38b9cf6d67a5..db1f6fee083b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -113,6 +113,37 @@ public static int squareDistance(byte[] a, byte[] b) { return IMPL.squareDistance(a, b); } + /** Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. */ + public static int int4SquareDistance(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + return IMPL.int4SquareDistance(a, b); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. The + * second vector is considered "packed" (i.e. every byte representing two values). + */ + public static int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) { + if (packed.length != ((unpacked.length + 1) >> 1)) { + throw new IllegalArgumentException( + "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length); + } + return IMPL.int4SquareDistanceSinglePacked(unpacked, packed); + } + + /** + * Returns the sum of squared differences between two uint4 (values between [0,15]) vectors. Both + * vectors are considered "packed" (i.e. every byte representing two values). 
+ */ + public static int int4SquareDistanceBothPacked(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + return IMPL.int4SquareDistanceBothPacked(a, b); + } + /** Returns the sum of squared differences of the two vectors where each byte is unsigned */ public static int uint8SquareDistance(byte[] a, byte[] b) { if (a.length != b.length) { @@ -189,15 +220,22 @@ public static int uint8DotProduct(byte[] a, byte[] b) { return IMPL.uint8DotProduct(a, b); } + /** + * Dot product computed over uint4 (values between [0,15]) bytes. + * + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the dot product of the two vectors + */ public static int int4DotProduct(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.int4DotProduct(a, false, b, false); + return IMPL.int4DotProduct(a, b); } /** - * Dot product computed over int4 (values between [0,15]) bytes. The second vector is considered + * Dot product computed over uint4 (values between [0,15]) bytes. The second vector is considered * "packed" (i.e. every byte representing two values). The following packing is assumed: * *
@@ -211,12 +249,28 @@ public static int int4DotProduct(byte[] a, byte[] b) {
* @param packed the packed vector, of length {@code (unpacked.length + 1) / 2}
* @return the value of the dot product of the two vectors
*/
- public static int int4DotProductPacked(byte[] unpacked, byte[] packed) {
+ public static int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
if (packed.length != ((unpacked.length + 1) >> 1)) {
throw new IllegalArgumentException(
"vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length);
}
- return IMPL.int4DotProduct(unpacked, false, packed, true);
+ return IMPL.int4DotProductSinglePacked(unpacked, packed);
+ }
+
+ /**
+ * Dot product computed over uint4 (values between [0,15]) bytes. Both vectors are considered
+ * "packed" (i.e. every byte representing two values).
+ *
+ * @param a bytes containing a packed vector
+ * @param b bytes containing another packed vector, of the same dimension
+ * @return the value of the dot product of the two vectors
+ */
+ public static int int4DotProductBothPacked(byte[] a, byte[] b) {
+ if (a.length != b.length) {
+ throw new IllegalArgumentException(
+ "vector dimensions differ: " + a.length + " != " + b.length);
+ }
+ return IMPL.int4DotProductBothPacked(a, b);
}
/**
diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java
new file mode 100644
index 000000000000..12b95f6c2ff2
--- /dev/null
+++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.internal.vectorization;
+
+import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED;
+import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteOrder;
+import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.apache.lucene.util.FloatToFloatFunction;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
+import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
+import org.apache.lucene.util.quantization.ScalarQuantizer;
+
+class Lucene99MemorySegmentScalarQuantizedVectorScorer implements FlatVectorsScorer {
+ static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE =
+ new Lucene99MemorySegmentScalarQuantizedVectorScorer();
+
+ private static final FlatVectorsScorer DELEGATE =
+ new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
+
+ private Lucene99MemorySegmentScalarQuantizedVectorScorer() {}
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
+ throws IOException {
+ if (vectorValues instanceof QuantizedByteVectorValues quantized
+ && quantized.getSlice() instanceof MemorySegmentAccessInput input) { // segment-backed slice
+ return new RandomVectorScorerSupplierImpl(similarityFunction, quantized, input);
+ }
+ return DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues); // fallback
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
+ throws IOException {
+ if (vectorValues instanceof QuantizedByteVectorValues quantized
+ && quantized.getSlice() instanceof MemorySegmentAccessInput input) { // segment-backed slice
+ return new RandomVectorScorerImpl(similarityFunction, quantized, input, target);
+ }
+ return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); // fallback
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
+ throws IOException {
+ return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); // byte[] targets always go to the delegate
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene99MemorySegmentScalarQuantizedVectorScorer()";
+ }
+
+ private abstract static class RandomVectorScorerBase
+ extends RandomVectorScorer.AbstractRandomVectorScorer {
+
+ private final ScalarQuantizer quantizer;
+ private final float constMultiplier;
+ private final MemorySegmentAccessInput input;
+ private final int vectorByteSize;
+ private final int nodeSize;
+ private final Scorer scorer;
+ private final FloatToFloatFunction scaler;
+ private byte[] scratch;
+
+ RandomVectorScorerBase(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues values,
+ MemorySegmentAccessInput input) {
+ super(values);
+
+ this.quantizer = values.getScalarQuantizer();
+ this.constMultiplier = this.quantizer.getConstantMultiplier();
+ this.input = input;
+ this.vectorByteSize = values.getVectorByteLength();
+ this.nodeSize = this.vectorByteSize + Float.BYTES; // vector bytes followed by one float
+
+ this.scorer = // raw integer score; the variant is fixed once, at construction
+ switch (similarityFunction) {
+ case EUCLIDEAN -> {
+ if (this.quantizer.getBits() <= 4) {
+ if (this.vectorByteSize != values.dimension()) { // byte length != dimension: compressed int4
+ yield this::compressedInt4Euclidean;
+ }
+ yield this::int4Euclidean;
+ }
+ yield this::euclidean;
+ }
+ case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> {
+ if (this.quantizer.getBits() <= 4) {
+ if (this.vectorByteSize != values.dimension()) { // byte length != dimension: compressed int4
+ yield this::compressedInt4DotProduct;
+ }
+ yield this::int4DotProduct;
+ }
+ yield this::dotProduct;
+ }
+ };
+
+ this.scaler = // applied in scoreBody to the corrected raw score
+ switch (similarityFunction) {
+ case EUCLIDEAN -> VectorUtil::normalizeDistanceToUnitInterval;
+ case DOT_PRODUCT, COSINE -> VectorUtil::normalizeToUnitInterval;
+ case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore;
+ };
+
+ checkInvariants(); // rejects inputs shorter than maxOrd() nodes
+ }
+
+ final void checkInvariants() {
+ if (input.length() < (long) nodeSize * maxOrd()) { // product is computed as a long
+ throw new IllegalArgumentException("input length is less than expected vector data");
+ }
+ }
+
+ final void checkOrdinal(int ord) {
+ if (ord < 0 || ord >= maxOrd()) { // valid ordinals are [0, maxOrd())
+ throw new IllegalArgumentException("illegal ordinal: " + ord);
+ }
+ }
+
+ ScalarQuantizer getQuantizer() {
+ return quantizer;
+ }
+
+ private static final ValueLayout.OfInt INT_UNALIGNED_LE =
+ JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+
+ // Uses asSlice: reinterpret() is restricted and throws on the heap-backed scratch fallback.
+ Node getNode(int ord) throws IOException {
+ checkOrdinal(ord);
+ long byteOffset = (long) ord * nodeSize;
+ MemorySegment node = input.segmentSliceOrNull(byteOffset, nodeSize);
+ if (node == null) { // no direct slice: copy the node into the reusable scratch array
+ if (scratch == null) {
+ scratch = new byte[nodeSize];
+ }
+ input.readBytes(byteOffset, scratch, 0, nodeSize);
+ node = MemorySegment.ofArray(scratch);
+ }
+ return new Node(
+ node.asSlice(0, vectorByteSize), // the quantized vector bytes
+ Float.intBitsToFloat(node.get(INT_UNALIGNED_LE, vectorByteSize))); // float stored after them
+ }
+
+ float scoreBody(int ord, float queryOffset) throws IOException {
+ // getNode(ord) validates the ordinal itself, so it is not checked a second time here.
+ Node node = getNode(ord);
+ return scaler.apply(scorer.score(node.vector) * constMultiplier + node.offset + queryOffset);
+ }
+
+ abstract int euclidean(MemorySegment doc); // chosen for EUCLIDEAN when the quantizer uses more than 4 bits
+
+ abstract int int4Euclidean(MemorySegment doc); // EUCLIDEAN, bits <= 4, vectorByteSize == dimension
+
+ abstract int compressedInt4Euclidean(MemorySegment doc); // EUCLIDEAN, bits <= 4, vectorByteSize != dimension
+
+ abstract int dotProduct(MemorySegment doc); // DOT_PRODUCT / COSINE / MAXIMUM_INNER_PRODUCT with more than 4 bits
+
+ abstract int int4DotProduct(MemorySegment doc); // dot-product family, bits <= 4, vectorByteSize == dimension
+
+ abstract int compressedInt4DotProduct(MemorySegment doc); // dot-product family, bits <= 4, vectorByteSize != dimension
+
+ record Node(MemorySegment vector, float offset) {} // what getNode returns: the vector bytes plus the float stored right after them
+
+ @FunctionalInterface
+ private interface Scorer { // integer kernel bound once in the constructor's similarity switch
+ int score(MemorySegment doc) throws IOException; // raw integer score for one stored vector
+ }
+ }
+
+ /**
+  * Scorer for a fixed float query. The query is quantized once at construction time and every
+  * stored vector is then compared against the resulting on-heap bytes.
+  */
+ private static class RandomVectorScorerImpl extends RandomVectorScorerBase {
+   private final byte[] targetBytes;
+   private final float queryOffset;
+
+   RandomVectorScorerImpl(
+       VectorSimilarityFunction similarityFunction,
+       QuantizedByteVectorValues values,
+       MemorySegmentAccessInput input,
+       float[] target) {
+     super(similarityFunction, values, input);
+     final byte[] quantized = new byte[target.length];
+     // presumably quantizeQuery fills "quantized" in place; its return value is kept as the
+     // query-side offset handed to scoreBody
+     final float offset = quantizeQuery(target, quantized, similarityFunction, getQuantizer());
+     this.targetBytes = quantized;
+     this.queryOffset = offset;
+   }
+
+   @Override
+   public float score(int node) throws IOException {
+     return scoreBody(node, queryOffset);
+   }
+
+   // --- dot-product kernels: on-heap query bytes vs. stored segment ---
+
+   @Override
+   int dotProduct(MemorySegment doc) {
+     return PanamaVectorUtilSupport.uint8DotProduct(targetBytes, doc);
+   }
+
+   @Override
+   int int4DotProduct(MemorySegment doc) {
+     return PanamaVectorUtilSupport.int4DotProduct(targetBytes, doc);
+   }
+
+   @Override
+   int compressedInt4DotProduct(MemorySegment doc) {
+     return PanamaVectorUtilSupport.int4DotProductSinglePacked(targetBytes, doc);
+   }
+
+   // --- square-distance kernels: on-heap query bytes vs. stored segment ---
+
+   @Override
+   int euclidean(MemorySegment doc) {
+     return PanamaVectorUtilSupport.uint8SquareDistance(targetBytes, doc);
+   }
+
+   @Override
+   int int4Euclidean(MemorySegment doc) {
+     return PanamaVectorUtilSupport.int4SquareDistance(targetBytes, doc);
+   }
+
+   @Override
+   int compressedInt4Euclidean(MemorySegment doc) {
+     return PanamaVectorUtilSupport.int4SquareDistanceSinglePacked(targetBytes, doc);
+   }
+ }
+
+ /**
+  * Supplier of updateable scorers that all share the same similarity function, quantized values
+  * and memory-segment input.
+  */
+ private record RandomVectorScorerSupplierImpl(
+     VectorSimilarityFunction similarityFunction,
+     QuantizedByteVectorValues values,
+     MemorySegmentAccessInput input)
+     implements RandomVectorScorerSupplier {
+
+   /** Returns a new supplier built from the same three components. */
+   @Override
+   public RandomVectorScorerSupplier copy() {
+     return new RandomVectorScorerSupplierImpl(similarityFunction(), values(), input());
+   }
+
+   /** Returns a new scorer whose query is selected later through {@code setScoringOrdinal}. */
+   @Override
+   public UpdateableRandomVectorScorer scorer() {
+     return new UpdateableRandomVectorScorerImpl(similarityFunction(), values(), input());
+   }
+ }
+
+ /**
+  * Scorer whose query is itself one of the stored vectors, selected with
+  * {@link #setScoringOrdinal(int)}. Query and document are both handed to the kernels as memory
+  * segments.
+  */
+ private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorerBase
+     implements UpdateableRandomVectorScorer {
+   private MemorySegment query;
+   private float queryOffset;
+
+   UpdateableRandomVectorScorerImpl(
+       VectorSimilarityFunction similarityFunction,
+       QuantizedByteVectorValues values,
+       MemorySegmentAccessInput input) {
+     super(similarityFunction, values, input);
+   }
+
+   /**
+    * Selects the stored vector at {@code ord} as the query for subsequent {@link #score(int)}
+    * calls.
+    *
+    * @throws IOException if the node cannot be read
+    */
+   @Override
+   public void setScoringOrdinal(int ord) throws IOException {
+     checkOrdinal(ord);
+     Node node = getNode(ord);
+     MemorySegment vector = node.vector();
+     // getNode() falls back to a shared scratch array when the input cannot provide a segment
+     // slice. A heap-backed vector is a view over that array, and the next getNode() call made
+     // by score() would overwrite it. Take a private copy so the query stays stable.
+     query =
+         vector.isNative()
+             ? vector
+             : MemorySegment.ofArray(vector.toArray(ValueLayout.JAVA_BYTE));
+     queryOffset = node.offset();
+   }
+
+   @Override
+   public float score(int node) throws IOException {
+     return scoreBody(node, queryOffset);
+   }
+
+   @Override
+   int euclidean(MemorySegment doc) {
+     return PanamaVectorUtilSupport.uint8SquareDistance(query, doc);
+   }
+
+   @Override
+   int int4Euclidean(MemorySegment doc) {
+     return PanamaVectorUtilSupport.int4SquareDistance(query, doc);
+   }
+
+   @Override
+   int compressedInt4Euclidean(MemorySegment doc) {
+     // both sides come from the index here, hence the "BothPacked" kernel
+     return PanamaVectorUtilSupport.int4SquareDistanceBothPacked(query, doc);
+   }
+
+   @Override
+   int dotProduct(MemorySegment doc) {
+     return PanamaVectorUtilSupport.uint8DotProduct(query, doc);
+   }
+
+   @Override
+   int int4DotProduct(MemorySegment doc) {
+     return PanamaVectorUtilSupport.int4DotProduct(query, doc);
+   }
+
+   @Override
+   int compressedInt4DotProduct(MemorySegment doc) {
+     // both sides come from the index here, hence the "BothPacked" kernel
+     return PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc);
+   }
+ }
+}
diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index a77c4846ca2a..ba612f750040 100644
--- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -360,7 +360,7 @@ public byte tail(int index) {
@Override
public int dotProduct(byte[] a, byte[] b) {
- return dotProductBody(new ArrayLoader(a), new ArrayLoader(b));
+ return dotProductBody(new ArrayLoader(a), new ArrayLoader(b), true); // true is dotProductBody's "signed" flag
}
@Override
@@ -369,15 +369,19 @@ public int uint8DotProduct(byte[] a, byte[] b) {
}
public static int dotProduct(byte[] a, MemorySegment b) {
- return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+ return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b), true); // signed = true
}
public static int dotProduct(MemorySegment a, MemorySegment b) {
- return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+ return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), true); // signed = true
}
- private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) {
- return dotProductBody(a, b, true);
+ public static int uint8DotProduct(byte[] a, MemorySegment b) { // heap array against memory segment
+ return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b), false); // signed = false
+ }
+
+ public static int uint8DotProduct(MemorySegment a, MemorySegment b) { // both operands in memory segments
+ return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), false); // signed = false
}
private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b, boolean signed) {
@@ -479,178 +483,198 @@ private static int dotProductBody128(
return acc.reduceLanes(ADD);
}
+ /**
+  * Species and chunk size shared by the int4 kernels, selected once from {@code VECTOR_BITSIZE}.
+  *
+  * <p>The short species is always twice as wide (in bits) as the byte species, so one byte vector
+  * widens into exactly one short vector. {@code CHUNK} is the number of bytes a kernel feeds into
+  * a short accumulator before that accumulator is widened to ints and reset.
+  */
+ private static final class Int4Constants {
+   static final VectorSpecies<Byte> BYTE_SPECIES;
+   static final VectorSpecies<Short> SHORT_SPECIES;
+   static final int CHUNK;
+
+   static {
+     if (VECTOR_BITSIZE >= 512) {
+       BYTE_SPECIES = ByteVector.SPECIES_256;
+       SHORT_SPECIES = ShortVector.SPECIES_512;
+       CHUNK = 4096;
+     } else if (VECTOR_BITSIZE == 256) {
+       BYTE_SPECIES = ByteVector.SPECIES_128;
+       SHORT_SPECIES = ShortVector.SPECIES_256;
+       CHUNK = 2048;
+     } else {
+       BYTE_SPECIES = ByteVector.SPECIES_64;
+       SHORT_SPECIES = ShortVector.SPECIES_128;
+       CHUNK = 1024;
+     }
+   }
+
+   private Int4Constants() {} // constants holder, never instantiated
+ }
+
@Override
- public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
- assert (apacked && bpacked) == false;
+ public int int4DotProduct(byte[] a, byte[] b) { // no packed flags any more: packed inputs use the *SinglePacked / *BothPacked methods
+ return int4DotProductBody(new ArrayLoader(a), new ArrayLoader(b));
+ }
+
+ public static int int4DotProduct(byte[] a, MemorySegment b) { // heap array against memory segment
+ return int4DotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+ }
+
+ public static int int4DotProduct(MemorySegment a, MemorySegment b) { // both operands in memory segments
+ return int4DotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+ }
+
+ private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b) { // driver: vector loop first, then scalar tail
int i = 0;
int res = 0;
- if (apacked || bpacked) {
- byte[] packed = apacked ? a : b;
- byte[] unpacked = apacked ? b : a;
- if (packed.length >= 32) {
- if (VECTOR_BITSIZE >= 512) {
- i += ByteVector.SPECIES_256.loopBound(packed.length);
- res += dotProductBody512Int4Packed(unpacked, packed, i);
- } else if (VECTOR_BITSIZE == 256) {
- i += ByteVector.SPECIES_128.loopBound(packed.length);
- res += dotProductBody256Int4Packed(unpacked, packed, i);
- } else {
- i += ByteVector.SPECIES_64.loopBound(packed.length);
- res += dotProductBody128Int4Packed(unpacked, packed, i);
- }
- }
- // scalar tail
- for (; i < packed.length; i++) {
- byte packedByte = packed[i];
- byte unpacked1 = unpacked[i];
- byte unpacked2 = unpacked[i + packed.length];
- res += (packedByte & 0x0F) * unpacked2;
- res += ((packedByte & 0xFF) >> 4) * unpacked1;
- }
- } else {
- if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) {
- return dotProduct(a, b);
- } else if (a.length >= 32) {
- i += ByteVector.SPECIES_128.loopBound(a.length);
- res += int4DotProductBody128(a, b, i);
- }
- // scalar tail
- for (; i < a.length; i++) {
- res += b[i] * a[i];
- }
+ if (a.length() >= 32) { // inputs shorter than 32 bytes are handled entirely by the scalar loop
+ i += Int4Constants.BYTE_SPECIES.loopBound(a.length()); // largest multiple of the byte-species length that fits
+ res += int4DotProductBody(a, b, i); // vector kernel covers [0, i)
+ }
+ // scalar tail
+ for (; i < a.length(); i++) {
+ res += a.tail(i) * b.tail(i);
}
-
return res;
}
- private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) {
+ private static int int4DotProductBody(ByteVectorLoader a, ByteVectorLoader b, int limit) { // vector kernel over the first "limit" bytes
int sum = 0;
- // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator
- for (int i = 0; i < limit; i += 4096) {
- ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512);
- ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512);
- int innerLimit = Math.min(limit - i, 4096);
- for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
- // packed
- var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j);
+ // iterate in chunks to ensure we don't overflow the short accumulator
+ for (int i = 0; i < limit; i += Int4Constants.CHUNK) {
+ ShortVector acc = ShortVector.zero(Int4Constants.SHORT_SPECIES); // fresh 16-bit accumulator for every chunk
+ int innerLimit = Math.min(limit - i, Int4Constants.CHUNK);
+ for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) {
// unpacked
- var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
+ ByteVector vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j);
+ Vector vb16 = vb8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); // widen to shorts before multiplying
- // upper
- ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
- Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
- acc0 = acc0.add(prod16);
+ // unpacked
+ ByteVector va8 = a.load(Int4Constants.BYTE_SPECIES, i + j);
+ Vector va16 = va8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0);
- // lower
- ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j);
- ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
- Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
- acc1 = acc1.add(prod16a);
+ acc = acc.add(vb16.mul(va16)); // multiply in 16-bit lanes and accumulate
}
- IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts();
- IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts();
- IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts();
- IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts();
- sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
+ Vector intAcc0 = acc.convert(S2I, 0); // parts 0 and 1 together cover every short lane
+ Vector intAcc1 = acc.convert(S2I, 1);
+ sum += intAcc0.add(intAcc1).reinterpretAsInts().reduceLanes(ADD);
}
return sum;
}
- private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) {
+ @Override
+ public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) { // first operand: one value per byte; second: two values per byte
+ return int4DotProductSinglePackedBody(new ArrayLoader(unpacked), new ArrayLoader(packed));
+ }
+
+ public static int int4DotProductSinglePacked(byte[] unpacked, MemorySegment packed) { // packed operand read from a memory segment
+ return int4DotProductSinglePackedBody(
+ new ArrayLoader(unpacked), new MemorySegmentLoader(packed));
+ }
+
+ private static int int4DotProductSinglePackedBody(
+ ByteVectorLoader unpacked, ByteVectorLoader packed) {
+ int i = 0;
+ int res = 0;
+ if (packed.length() >= 32) {
+ i += Int4Constants.BYTE_SPECIES.loopBound(packed.length());
+ res += int4DotProductSinglePackedBody(unpacked, packed, i);
+ }
+ // scalar tail
+ for (; i < packed.length(); i++) {
+ byte packedByte = packed.tail(i);
+ byte unpacked1 = unpacked.tail(i);
+ byte unpacked2 = unpacked.tail(i + packed.length());
+ res += (packedByte & 0x0F) * unpacked2;
+ res += ((packedByte & 0xFF) >> 4) * unpacked1;
+ }
+ return res;
+ }
+
+ private static int int4DotProductSinglePackedBody(
+ ByteVectorLoader unpacked, ByteVectorLoader packed, int limit) { // vector kernel over the first "limit" packed bytes
int sum = 0;
- // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator
- for (int i = 0; i < limit; i += 2048) {
- ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256);
- ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256);
- int innerLimit = Math.min(limit - i, 2048);
- for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
+ // iterate in chunks to ensure we don't overflow the short accumulator
+ for (int i = 0; i < limit; i += Int4Constants.CHUNK) {
+ ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the low-nibble products
+ ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the high-nibble products
+ int innerLimit = Math.min(limit - i, Int4Constants.CHUNK);
+ for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) {
// packed
- var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j);
- // unpacked
- var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
+ ByteVector vb8 = packed.load(Int4Constants.BYTE_SPECIES, i + j);
// upper
+ ByteVector va8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j + packed.length()); // second half of the unpacked vector pairs with the low nibbles
ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
- Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
+ Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); // zero-extend: the byte product is treated as unsigned
acc0 = acc0.add(prod16);
// lower
- ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j);
+ ByteVector vc8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j); // first half of the unpacked vector pairs with the high nibbles
ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8);
- Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
+ Vector prod16a =
+ prod8a.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0);
acc1 = acc1.add(prod16a);
}
- IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts();
- IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts();
- IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts();
- IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts();
- sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
+ Vector intAcc0 = acc0.convert(S2I, 0); // each short accumulator is widened in two parts
+ Vector intAcc1 = acc0.convert(S2I, 1);
+ Vector intAcc2 = acc1.convert(S2I, 0);
+ Vector intAcc3 = acc1.convert(S2I, 1);
+ sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD);
}
return sum;
}
- /** vectorized dot product body (128 bit vectors) */
- private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) {
- int sum = 0;
- // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator
- for (int i = 0; i < limit; i += 1024) {
- ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128);
- ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
- int innerLimit = Math.min(limit - i, 1024);
- for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
- // packed
- ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j);
- // unpacked
- ByteVector va8 =
- ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
+ @Override
+ public int int4DotProductBothPacked(byte[] a, byte[] b) { // both operands carry two values per byte
+ return int4DotProductBothPackedBody(new ArrayLoader(a), new ArrayLoader(b));
+ }
- // upper
- ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8);
- ShortVector prod16 =
- prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
- acc0 = acc0.add(prod16.and((short) 0xFF));
+ public static int int4DotProductBothPacked(MemorySegment a, MemorySegment b) { // both operands in memory segments
+ return int4DotProductBothPackedBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+ }
- // lower
- va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j);
- prod8 = vb8.lanewise(LSHR, 4).mul(va8);
- prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
- acc1 = acc1.add(prod16.and((short) 0xFF));
- }
- IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
- IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
- IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
- IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
- sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
+ private static int int4DotProductBothPackedBody(ByteVectorLoader a, ByteVectorLoader b) { // driver: vector loop first, then scalar tail
+ int i = 0;
+ int res = 0;
+ if (a.length() >= 32) { // inputs shorter than 32 bytes are handled entirely by the scalar loop
+ i += Int4Constants.BYTE_SPECIES.loopBound(a.length());
+ res += int4DotProductBothPackedBody(a, b, i); // vector kernel covers [0, i)
}
- return sum;
+ // scalar tail
+ for (; i < a.length(); i++) {
+ byte aByte = a.tail(i);
+ byte bByte = b.tail(i);
+ res += (aByte & 0x0F) * (bByte & 0x0F); // low nibble times low nibble
+ res += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4); // high nibble times high nibble; & 0xFF avoids sign extension
+ }
+ return res;
}
- private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
+ private static int int4DotProductBothPackedBody(
+ ByteVectorLoader a, ByteVectorLoader b, int limit) { // vector kernel over the first "limit" packed bytes
int sum = 0;
- // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator
- for (int i = 0; i < limit; i += 1024) {
- ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128);
- ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
- int innerLimit = Math.min(limit - i, 1024);
- for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
- ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
- ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
- ByteVector prod8 = va8.mul(vb8);
- ShortVector prod16 =
- prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
- acc0 = acc0.add(prod16.and((short) 0xFF));
-
- va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
- vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
- prod8 = va8.mul(vb8);
- prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
- acc1 = acc1.add(prod16.and((short) 0xFF));
+ // iterate in chunks to ensure we don't overflow the short accumulator
+ for (int i = 0; i < limit; i += Int4Constants.CHUNK) {
+ ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the low-nibble products
+ ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the high-nibble products
+ int innerLimit = Math.min(limit - i, Int4Constants.CHUNK);
+ for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) {
+ // packed
+ var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j);
+ // packed
+ var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j);
+
+ // upper
+ ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F)); // low nibbles of both operands
+ Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0); // zero-extend: the byte product is treated as unsigned
+ acc0 = acc0.add(prod16);
+
+ // lower
+ ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(va8.lanewise(LSHR, 4)); // high nibbles of both operands
+ Vector prod16a =
+ prod8a.convertShape(ZERO_EXTEND_B2S, Int4Constants.SHORT_SPECIES, 0);
+ acc1 = acc1.add(prod16a);
}
- IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
- IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
- IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
- IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
- sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
+ Vector intAcc0 = acc0.convert(S2I, 0); // each short accumulator is widened in two parts
+ Vector intAcc1 = acc0.convert(S2I, 1);
+ Vector intAcc2 = acc1.convert(S2I, 0);
+ Vector intAcc3 = acc1.convert(S2I, 1);
+ sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD);
}
return sum;
}
@@ -788,7 +812,7 @@ private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int
@Override
public int squareDistance(byte[] a, byte[] b) {
- return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
+ return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b), true); // true is squareDistanceBody's "signed" flag
}
@Override
@@ -797,15 +821,19 @@ public int uint8SquareDistance(byte[] a, byte[] b) {
}
public static int squareDistance(MemorySegment a, MemorySegment b) {
- return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+ return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), true); // signed = true
}
public static int squareDistance(byte[] a, MemorySegment b) {
- return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+ return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b), true); // signed = true
}
- private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
- return squareDistanceBody(a, b, true);
+ public static int uint8SquareDistance(MemorySegment a, MemorySegment b) { // both operands in memory segments
+ return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b), false); // signed = false
+ }
+
+ public static int uint8SquareDistance(byte[] a, MemorySegment b) { // heap array against memory segment
+ return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b), false); // signed = false
}
private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, boolean signed) {
@@ -886,6 +914,183 @@ private static int squareDistanceBody128(
return acc1.add(acc2).reduceLanes(ADD);
}
+ @Override
+ public int int4SquareDistance(byte[] a, byte[] b) {
+ return int4SquareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
+ }
+
+ public static int int4SquareDistance(byte[] a, MemorySegment b) {
+ return int4SquareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+ }
+
+ public static int int4SquareDistance(MemorySegment a, MemorySegment b) {
+ return int4SquareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+ }
+
+ private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
+ int i = 0;
+ int res = 0;
+ if (a.length() >= 32) {
+ i += Int4Constants.BYTE_SPECIES.loopBound(a.length());
+ res += int4SquareDistanceBody(a, b, i);
+ }
+ // scalar tail
+ for (; i < a.length(); i++) {
+ int diff = a.tail(i) - b.tail(i);
+ res += diff * diff;
+ }
+ return res;
+ }
+
+ private static int int4SquareDistanceBody(ByteVectorLoader a, ByteVectorLoader b, int limit) { // vector kernel over the first "limit" bytes
+ int sum = 0;
+ // iterate in chunks to ensure we don't overflow the short accumulator
+ for (int i = 0; i < limit; i += Int4Constants.CHUNK) {
+ ShortVector acc = ShortVector.zero(Int4Constants.SHORT_SPECIES); // fresh 16-bit accumulator for every chunk
+ int innerLimit = Math.min(limit - i, Int4Constants.CHUNK);
+ for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) {
+ // unpacked
+ var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j);
+ // unpacked
+ var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j);
+
+ ByteVector diff8 = vb8.sub(va8); // subtract while still in byte lanes
+ Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); // sign-extending widen, then square in 16 bits
+ acc = acc.add(diff16.mul(diff16));
+ }
+ Vector intAcc0 = acc.convert(S2I, 0); // parts 0 and 1 together cover every short lane
+ Vector intAcc1 = acc.convert(S2I, 1);
+ sum += intAcc0.add(intAcc1).reinterpretAsInts().reduceLanes(ADD);
+ }
+ return sum;
+ }
+
+ @Override
+ public int int4SquareDistanceSinglePacked(byte[] a, byte[] b) {
+ return int4SquareDistanceSinglePackedBody(new ArrayLoader(a), new ArrayLoader(b));
+ }
+
+ public static int int4SquareDistanceSinglePacked(byte[] a, MemorySegment b) {
+ return int4SquareDistanceSinglePackedBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+ }
+
+ private static int int4SquareDistanceSinglePackedBody(
+ ByteVectorLoader unpacked, ByteVectorLoader packed) {
+ int i = 0;
+ int res = 0;
+ if (packed.length() >= 32) {
+ i += Int4Constants.BYTE_SPECIES.loopBound(packed.length());
+ res += int4SquareDistanceSinglePackedBody(unpacked, packed, i);
+ }
+ // scalar tail
+ for (; i < packed.length(); i++) {
+ byte packedByte = packed.tail(i);
+ byte unpacked1 = unpacked.tail(i);
+ byte unpacked2 = unpacked.tail(i + packed.length());
+
+ int diff1 = (packedByte & 0x0F) - unpacked2;
+ int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1;
+
+ res += diff1 * diff1 + diff2 * diff2;
+ }
+ return res;
+ }
+
+ private static int int4SquareDistanceSinglePackedBody(
+ ByteVectorLoader unpacked, ByteVectorLoader packed, int limit) { // vector kernel over the first "limit" packed bytes
+ int sum = 0;
+ // iterate in chunks to ensure we don't overflow the short accumulator
+ for (int i = 0; i < limit; i += Int4Constants.CHUNK) {
+ ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the low-nibble squared differences
+ ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the high-nibble squared differences
+ int innerLimit = Math.min(limit - i, Int4Constants.CHUNK);
+ for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) {
+ // packed
+ ByteVector vb8 = packed.load(Int4Constants.BYTE_SPECIES, i + j);
+
+ // upper
+ ByteVector va8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j + packed.length()); // second half of the unpacked vector pairs with the low nibbles
+ ByteVector diff8 = vb8.and((byte) 0x0F).sub(va8);
+ Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); // sign-extending widen: the difference may be negative
+ acc0 = acc0.add(diff16.mul(diff16));
+
+ // lower
+ ByteVector vc8 = unpacked.load(Int4Constants.BYTE_SPECIES, i + j); // first half of the unpacked vector pairs with the high nibbles
+ ByteVector diff8a = vb8.lanewise(LSHR, 4).sub(vc8);
+ Vector diff16a = diff8a.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0);
+ acc1 = acc1.add(diff16a.mul(diff16a));
+ }
+ Vector intAcc0 = acc0.convert(S2I, 0); // each short accumulator is widened in two parts
+ Vector intAcc1 = acc0.convert(S2I, 1);
+ Vector intAcc2 = acc1.convert(S2I, 0);
+ Vector intAcc3 = acc1.convert(S2I, 1);
+ sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD);
+ }
+ return sum;
+ }
+
+ @Override
+ public int int4SquareDistanceBothPacked(byte[] a, byte[] b) {
+ return int4SquareDistanceBothPackedBody(new ArrayLoader(a), new ArrayLoader(b));
+ }
+
+ public static int int4SquareDistanceBothPacked(MemorySegment a, MemorySegment b) {
+ return int4SquareDistanceBothPackedBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+ }
+
+ private static int int4SquareDistanceBothPackedBody(ByteVectorLoader a, ByteVectorLoader b) {
+ int i = 0;
+ int res = 0;
+ if (a.length() >= 32) {
+ i += Int4Constants.BYTE_SPECIES.loopBound(a.length());
+ res += int4SquareDistanceBothPackedBody(a, b, i);
+ }
+ // scalar tail
+ for (; i < a.length(); i++) {
+ byte aByte = a.tail(i);
+ byte bByte = b.tail(i);
+
+ int diff1 = (aByte & 0x0F) - (bByte & 0x0F);
+ int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4);
+
+ res += diff1 * diff1 + diff2 * diff2;
+ }
+ return res;
+ }
+
+ private static int int4SquareDistanceBothPackedBody(
+ ByteVectorLoader a, ByteVectorLoader b, int limit) { // vector kernel over the first "limit" packed bytes
+ int sum = 0;
+ // iterate in chunks to ensure we don't overflow the short accumulator
+ for (int i = 0; i < limit; i += Int4Constants.CHUNK) {
+ ShortVector acc0 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the low-nibble squared differences
+ ShortVector acc1 = ShortVector.zero(Int4Constants.SHORT_SPECIES); // accumulates the high-nibble squared differences
+ int innerLimit = Math.min(limit - i, Int4Constants.CHUNK);
+ for (int j = 0; j < innerLimit; j += Int4Constants.BYTE_SPECIES.length()) {
+ // packed
+ var vb8 = b.load(Int4Constants.BYTE_SPECIES, i + j);
+ // packed
+ var va8 = a.load(Int4Constants.BYTE_SPECIES, i + j);
+
+ // upper
+ ByteVector diff8 = vb8.and((byte) 0x0F).sub(va8.and((byte) 0x0F)); // low nibbles of both operands
+ Vector diff16 = diff8.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0); // sign-extending widen: the difference may be negative
+ acc0 = acc0.add(diff16.mul(diff16));
+
+ // lower
+ ByteVector diff8a = vb8.lanewise(LSHR, 4).sub(va8.lanewise(LSHR, 4)); // high nibbles of both operands
+ Vector diff16a = diff8a.convertShape(B2S, Int4Constants.SHORT_SPECIES, 0);
+ acc1 = acc1.add(diff16a.mul(diff16a));
+ }
+ Vector intAcc0 = acc0.convert(S2I, 0); // each short accumulator is widened in two parts
+ Vector intAcc1 = acc0.convert(S2I, 1);
+ Vector intAcc2 = acc1.convert(S2I, 0);
+ Vector intAcc3 = acc1.convert(S2I, 1);
+ sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reinterpretAsInts().reduceLanes(ADD);
+ }
+ return sum;
+ }
+
// Experiments suggest that we need at least 8 lanes so that the overhead of going with the vector
// approach and counting trues on vector masks pays off.
private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8;
diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
index 54b3be67afcb..cf3ab94f417c 100644
--- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
+++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
@@ -78,6 +78,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE;
}
+ @Override
+ public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
+ return Lucene99MemorySegmentScalarQuantizedVectorScorer.INSTANCE; // the shared INSTANCE of the memory-segment based scorer
+ }
+
@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException {
if (input instanceof MemorySegmentAccessInput msai) {
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
index 2c6c54cece73..3ad2cab88690 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
@@ -308,10 +308,19 @@ public KnnVectorsFormat knnVectorsFormat() {
}
};
String expectedPattern =
- "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))";
- var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+ "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=%s, rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s)))";
+ var defaultScorer =
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())",
+ "DefaultFlatVectorScorer()");
var memSegScorer =
- format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "Lucene99MemorySegmentScalarQuantizedVectorScorer()",
+ "Lucene99MemorySegmentFlatVectorsScorer()");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java
index e04054c27e37..7156afd9cc3c 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java
@@ -372,10 +372,19 @@ public KnnVectorsFormat knnVectorsFormat() {
}
};
String expectedPattern =
- "Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
- var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+ "Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=%s, rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s))";
+ var defaultScorer =
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())",
+ "DefaultFlatVectorScorer()");
var memSegScorer =
- format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "Lucene99MemorySegmentScalarQuantizedVectorScorer()",
+ "Lucene99MemorySegmentFlatVectorsScorer()");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}
diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
index 7ec661b3659f..78280e7e4c36 100644
--- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
+++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
@@ -107,11 +107,23 @@ public void testInt4DotProduct() {
b[i] = (byte) random().nextInt(16);
}
- assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true));
- assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false));
+ assertIntReturningProviders(p -> p.int4DotProduct(a, b));
+ assertIntReturningProviders(p -> p.int4DotProductSinglePacked(a, pack(b)));
+ assertIntReturningProviders(p -> p.int4DotProductSinglePacked(b, pack(a)));
+ assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b)));
+
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b)));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a)));
assertEquals(
LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
- PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true));
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b)));
}
public void testInt4DotProductBoundaries() {
@@ -122,20 +134,106 @@ public void testInt4DotProductBoundaries() {
Arrays.fill(a, MAX_VALUE);
Arrays.fill(b, MAX_VALUE);
- assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true));
- assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false));
+
+ assertIntReturningProviders(p -> p.int4DotProduct(a, b));
+ assertIntReturningProviders(p -> p.int4DotProductSinglePacked(a, pack(b)));
+ assertIntReturningProviders(p -> p.int4DotProductSinglePacked(b, pack(a)));
+ assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b)));
+
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b)));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a)));
assertEquals(
LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
- PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true));
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b)));
byte MIN_VALUE = 0;
Arrays.fill(a, MIN_VALUE);
Arrays.fill(b, MIN_VALUE);
- assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true));
- assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false));
+
+ assertIntReturningProviders(p -> p.int4DotProduct(a, b));
+ assertIntReturningProviders(p -> p.int4DotProductSinglePacked(a, pack(b)));
+ assertIntReturningProviders(p -> p.int4DotProductSinglePacked(b, pack(a)));
+ assertIntReturningProviders(p -> p.int4DotProductBothPacked(pack(a), pack(b)));
+
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, b));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(a, pack(b)));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductSinglePacked(b, pack(a)));
+ assertEquals(
+ LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
+ PANAMA_PROVIDER.getVectorUtilSupport().int4DotProductBothPacked(pack(a), pack(b)));
+ }
+
+  public void testInt4SquareDistance() {
+    assumeTrue("even sizes only", size % 2 == 0);
+
+    // Two random vectors of unpacked int4 values, i.e. one value in [0, 16) per byte.
+    final byte[] first = new byte[size];
+    final byte[] second = new byte[size];
+    for (int idx = 0; idx < size; idx++) {
+      first[idx] = (byte) random().nextInt(16);
+      second[idx] = (byte) random().nextInt(16);
+    }
+
+    // Run each input layout: unpacked, one side packed (in both argument orders),
+    // and both sides packed.
+    assertIntReturningProviders(p -> p.int4SquareDistance(first, second));
+    assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(first, pack(second)));
+    assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(second, pack(first)));
+    assertIntReturningProviders(p -> p.int4SquareDistanceBothPacked(pack(first), pack(second)));
+
+    // The byte-wise squareDistance of the unpacked vectors is the reference value
+    // that every int4 variant is compared against.
+    var expected = LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(first, second);
+    var panama = PANAMA_PROVIDER.getVectorUtilSupport();
+    assertEquals(expected, panama.int4SquareDistance(first, second));
+    assertEquals(expected, panama.int4SquareDistanceSinglePacked(first, pack(second)));
+    assertEquals(expected, panama.int4SquareDistanceSinglePacked(second, pack(first)));
+    assertEquals(expected, panama.int4SquareDistanceBothPacked(pack(first), pack(second)));
+  }
+
+  public void testInt4SquareDistanceBoundaries() {
+    assumeTrue("even sizes only", size % 2 == 0);
+
+    // squareDistance is maximized when the points are farthest apart: an all-15 vector against
+    // an all-0 vector makes every dimension contribute (15 - 0)^2 to the sum.
+    byte MAX_VALUE = 15;
+    var a = new byte[size];
+    Arrays.fill(a, MAX_VALUE);
+
+    byte MIN_VALUE = 0;
+    var b = new byte[size];
+    Arrays.fill(b, MIN_VALUE);
+
+    assertIntReturningProviders(p -> p.int4SquareDistance(a, b));
+    assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(a, pack(b)));
+    assertIntReturningProviders(p -> p.int4SquareDistanceSinglePacked(b, pack(a)));
+    assertIntReturningProviders(p -> p.int4SquareDistanceBothPacked(pack(a), pack(b)));
+
+    assertEquals(
+        LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b),
+        PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistance(a, b));
+    assertEquals(
+        LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b),
+        PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(a, pack(b)));
+    assertEquals(
+        LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b),
+        PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceSinglePacked(b, pack(a)));
     assertEquals(
-        LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b),
-        PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true));
+        LUCENE_PROVIDER.getVectorUtilSupport().squareDistance(a, b),
+        PANAMA_PROVIDER.getVectorUtilSupport().int4SquareDistanceBothPacked(pack(a), pack(b)));
   }
public void testInt4BitDotProduct() {