diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 808d7b3cc882..7a92c077bbc7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -37,4 +37,12 @@ private FlatVectorScorerUtil() {} public static FlatVectorsScorer getLucene99FlatVectorsScorer() { return IMPL.getLucene99FlatVectorsScorer(); } + + /** + * Returns a FlatVectorsScorer that supports scalar quantized vectors in the Lucene99 format. + * Scorers retrieved through this method may be optimized on certain platforms. Otherwise, a + * DefaultFlatVectorScorer is returned. + */ + public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return IMPL.getLucene99ScalarQuantizedVectorsScorer(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 552260894a8d..3533080c963a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -18,10 +18,10 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; 
final boolean compress; - final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; + final FlatVectorsScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -109,8 +109,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = - new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer(); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index 2127a594117a..60ec3e572697 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.store.IndexInput; /** Default provider returning scalar implementations. 
*/ @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return DefaultFlatVectorScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer()); + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) { return new DefaultPostingDecodingUtil(input); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index eeb1830fc916..65c411d72466 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -97,6 +97,9 @@ public static VectorizationProvider getInstance() { /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + /** Returns a FlatVectorsScorer that supports scalar quantized vectors in the Lucene99 format. */ + public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer(); + /** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */ public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException; diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java new file mode 100644 index 000000000000..89a7e96674e0 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; + +public class Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer implements FlatVectorsScorer { + + private final FlatVectorsScorer delegate; + + public Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer(FlatVectorsScorer delegate) { + this.delegate = delegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean + // similarity + // So, we delegate to the default scorer + if 
(quantizedByteVectorValues.getScalarQuantizer().getBits() == 4 + && similarityType == VectorSimilarityFunction.EUCLIDEAN + // Indicates that the vector is compressed as the byte length is not equal to the + // dimension count + && vectorValues.getVectorByteLength() != vectorValues.dimension()) { + return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); + } + var scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + var scalarScorerSupplier = + Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.create( + similarityType, + scalarQuantizer.getBits(), + scalarQuantizer.getConstantMultiplier(), + quantizedByteVectorValues); + if (scalarScorerSupplier.isPresent()) { + return scalarScorerSupplier.get(); + } + } + return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityType, + RandomAccessVectorValues vectorValues, + float[] queryVector) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean + // similarity + // So, we delegate to the default scorer + if (quantizedByteVectorValues.getScalarQuantizer().getBits() == 4 + && similarityType == VectorSimilarityFunction.EUCLIDEAN + // Indicates that the vector is compressed as the byte length is not equal to the + // dimension count + && vectorValues.getVectorByteLength() != vectorValues.dimension()) { + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); + } + checkDimensions(queryVector.length, vectorValues.dimension()); + var scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + byte[] targetBytes = new byte[queryVector.length]; + float offsetCorrection = + quantizeQuery(queryVector, targetBytes, similarityType, scalarQuantizer); + var scalarScorer = + 
Lucene99MemorySegmentScalarQuantizedVectorScorer.create( + similarityType, + targetBytes, + offsetCorrection, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits(), + quantizedByteVectorValues); + if (scalarScorer.isPresent()) { + return scalarScorer.get(); + } + } + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityType, + RandomAccessVectorValues vectorValues, + byte[] queryVector) + throws IOException { + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); + } + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException( + "vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } + + @Override + public String toString() { + return "Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer()"; + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..c6ef7f6cea83 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; + +abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + + final int vectorByteSize, vectorByteOffset; + final MemorySegmentAccessInput input; + final MemorySegment query; + final float constMultiplier; + byte[] scratch; + + /** + * Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is + * returned. 
+ */ + public static Optional create( + VectorSimilarityFunction similarityType, + byte[] targetBytes, + float offsetCorrection, + float constMultiplier, + byte bits, + RandomAccessQuantizedByteVectorValues values) { + IndexInput input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if (!(input instanceof MemorySegmentAccessInput msInput)) { + return Optional.empty(); + } + checkInvariants(values.size(), values.getVectorByteLength(), input); + final boolean compressed = values.getVectorByteLength() != values.dimension(); + if (compressed) { + assert bits == 4; + assert values.getVectorByteLength() == values.dimension() / 2; + } + return switch (similarityType) { + case COSINE, DOT_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4DotProductScorer( + msInput, values, targetBytes, constMultiplier, offsetCorrection, compressed)); + } + yield Optional.of( + new DotProductScorer(msInput, values, targetBytes, constMultiplier, offsetCorrection)); + } + case EUCLIDEAN -> Optional.of( + new EuclideanScorer(msInput, values, targetBytes, constMultiplier)); + case MAXIMUM_INNER_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4MaxInnerProductScorer( + msInput, values, targetBytes, constMultiplier, offsetCorrection, compressed)); + } + yield Optional.of( + new MaxInnerProductScorer( + msInput, values, targetBytes, constMultiplier, offsetCorrection)); + } + }; + } + + Lucene99MemorySegmentScalarQuantizedVectorScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] queryVector, + float constMultiplier) { + super(values); + this.input = input; + this.vectorByteSize = values.getVectorByteLength(); + this.vectorByteOffset = values.getVectorByteLength() + Float.BYTES; + this.query = MemorySegment.ofArray(queryVector); + this.constMultiplier = constMultiplier; + } + + final MemorySegment getSegment(int ord) throws IOException { + 
checkOrdinal(ord); + long byteOffset = (long) ord * vectorByteOffset; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + final float getOffsetCorrection(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = ((long) ord * vectorByteOffset) + vectorByteSize; + int floatInts = input.readInt(byteOffset); + return Float.intBitsToFloat(floatInts); + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final float offsetCorrection; + + DotProductScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection) { + super(input, values, query, constMultiplier); + this.offsetCorrection = offsetCorrection; + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + float vectorOffset = getOffsetCorrection(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return Math.max((1 + adjustedDistance) / 2, 0); + } + } + + static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private 
final boolean compressed; + private final float offsetCorrection; + + Int4DotProductScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection, + boolean compressed) { + super(input, values, query, constMultiplier); + this.compressed = compressed; + this.offsetCorrection = offsetCorrection; + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), compressed); + float vectorOffset = getOffsetCorrection(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return Math.max((1 + adjustedDistance) / 2, 0); + } + } + + static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + EuclideanScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier) { + super(input, values, query, constMultiplier); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node)); + float adjustedDistance = raw * constMultiplier; + return 1 / (1f + adjustedDistance); + } + } + + static final class MaxInnerProductScorer + extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final float offsetCorrection; + + MaxInnerProductScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection) { + super(input, values, query, constMultiplier); + this.offsetCorrection = offsetCorrection; + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = 
PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + float vectorOffset = getOffsetCorrection(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + } + + static final class Int4MaxInnerProductScorer + extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final boolean compressed; + private final float offsetCorrection; + + Int4MaxInnerProductScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection, + boolean compressed) { + super(input, values, query, constMultiplier); + this.compressed = compressed; + this.offsetCorrection = offsetCorrection; + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), compressed); + float vectorOffset = getOffsetCorrection(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java new file mode 100644 index 000000000000..787bfcfbf316 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java @@ -0,0 
+1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; + +/** A score supplier of vectors whose element size is byte. */ +public abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier + implements RandomVectorScorerSupplier { + + final int vectorByteSize, vectorByteOffset; + final int maxOrd; + final MemorySegmentAccessInput input; + final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds + byte[] scratch1, scratch2; + final float constMultiplier; + + /** + * Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty + * optional is returned. 
+ */ + static Optional create( + VectorSimilarityFunction type, + byte bits, + float constMultiplier, + RandomAccessQuantizedByteVectorValues values) { + IndexInput input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if (!(input instanceof MemorySegmentAccessInput msInput)) { + return Optional.empty(); + } + final boolean compressed = values.getVectorByteLength() != values.dimension(); + if (compressed) { + assert bits == 4; + assert values.getVectorByteLength() == values.dimension() / 2; + } + checkInvariants(values.size(), values.getVectorByteLength(), input); + return switch (type) { + case COSINE, DOT_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4DotProductSupplier(msInput, values, constMultiplier, compressed)); + } + yield Optional.of(new DotProductSupplier(msInput, values, constMultiplier)); + } + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, constMultiplier)); + case MAXIMUM_INNER_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4MaxInnerProductSupplier(msInput, values, constMultiplier, compressed)); + } + yield Optional.of(new MaxInnerProductSupplier(msInput, values, constMultiplier)); + } + }; + } + + Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + this.input = input; + this.values = values; + this.vectorByteSize = values.getVectorByteLength(); + this.vectorByteOffset = values.getVectorByteLength() + Float.BYTES; + this.maxOrd = values.size(); + this.constMultiplier = constMultiplier; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + static void decompressBytes(byte[] compressed, int numBytes) { + if (numBytes 
== compressed.length) { + return; + } + if (numBytes << 1 != compressed.length) { + throw new IllegalArgumentException( + "numBytes: " + numBytes + " does not match compressed length: " + compressed.length); + } + for (int i = 0; i < numBytes; ++i) { + compressed[numBytes + i] = (byte) (compressed[i] & 0x0F); + compressed[i] = (byte) ((compressed[i] & 0xFF) >> 4); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + final MemorySegment getFirstSegment(int ord) throws IOException { + long byteOffset = (long) ord * vectorByteOffset; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + // we always read and decompress the full vector if the value is compressed + // Generally, this is OK, as the scorer is used many times after the initial decompression + if (seg == null || values.dimension() != vectorByteSize) { + if (scratch1 == null) { + scratch1 = new byte[values.dimension()]; + } + input.readBytes(byteOffset, scratch1, 0, vectorByteSize); + if (values.dimension() != vectorByteSize) { + decompressBytes(scratch1, vectorByteSize); + } + seg = MemorySegment.ofArray(scratch1); + } + return seg; + } + + final MemorySegment getSecondSegment(int ord) throws IOException { + long byteOffset = (long) ord * vectorByteOffset; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch2 == null) { + scratch2 = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch2, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch2); + } + return seg; + } + + final float getOffsetCorrection(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = ((long) ord * vectorByteOffset) + vectorByteSize; + int floatInts = input.readInt(byteOffset); + return Float.intBitsToFloat(floatInts); + } + + static final class DotProductSupplier + extends 
Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + DotProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + super(input, values, constMultiplier); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(querySegment, nodeSegment); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + return Math.max((1 + adjustedDistance) / 2, 0); + } + }; + } + + @Override + public DotProductSupplier copy() throws IOException { + return new DotProductSupplier(input.clone(), values, constMultiplier); + } + } + + static final class Int4DotProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + private final boolean compressed; + + Int4DotProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + boolean compressed) { + super(input, values, constMultiplier); + this.compressed = compressed; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = + 
PanamaVectorUtilSupport.int4DotProduct(querySegment, false, nodeSegment, compressed); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + return Math.max((1 + adjustedDistance) / 2, 0); + } + }; + } + + @Override + public Int4DotProductSupplier copy() throws IOException { + return new Int4DotProductSupplier(input.clone(), values, constMultiplier, compressed); + } + } + + static final class EuclideanSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + EuclideanSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + super(input, values, constMultiplier); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.squareDistance(querySegment, getSecondSegment(node)); + float adjustedDistance = raw * constMultiplier; + return 1 / (1f + adjustedDistance); + } + }; + } + + @Override + public EuclideanSupplier copy() throws IOException { + return new EuclideanSupplier(input.clone(), values, constMultiplier); + } + } + + static final class MaxInnerProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + MaxInnerProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + super(input, values, constMultiplier); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); + return new 
RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(querySegment, nodeSegment); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + }; + } + + @Override + public MaxInnerProductSupplier copy() throws IOException { + return new MaxInnerProductSupplier(input.clone(), values, constMultiplier); + } + } + + static final class Int4MaxInnerProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + private final boolean compressed; + + Int4MaxInnerProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + boolean compressed) { + super(input, values, constMultiplier); + this.compressed = compressed; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(querySegment, false, nodeSegment, compressed); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + }; + } + + @Override + public 
Int4MaxInnerProductSupplier copy() throws IOException { + return new Int4MaxInnerProductSupplier(input.clone(), values, constMultiplier, compressed); + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index ad2dff11cea1..c59912844d8b 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -392,48 +392,55 @@ private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit @Override public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { assert (apacked && bpacked) == false; + return int4DotProduct(MemorySegment.ofArray(a), apacked, MemorySegment.ofArray(b), bpacked); + } + + public static int int4DotProduct( + MemorySegment a, boolean apacked, MemorySegment b, boolean bpacked) { + assert (apacked && bpacked) == false; int i = 0; int res = 0; if (apacked || bpacked) { - byte[] packed = apacked ? a : b; - byte[] unpacked = apacked ? b : a; - if (packed.length >= 32) { + MemorySegment packed = apacked ? a : b; + MemorySegment unpacked = apacked ? 
b : a; + if (packed.byteSize() >= 32) { if (VECTOR_BITSIZE >= 512) { - i += ByteVector.SPECIES_256.loopBound(packed.length); + i += ByteVector.SPECIES_256.loopBound(packed.byteSize()); res += dotProductBody512Int4Packed(unpacked, packed, i); } else if (VECTOR_BITSIZE == 256) { - i += ByteVector.SPECIES_128.loopBound(packed.length); + i += ByteVector.SPECIES_128.loopBound(packed.byteSize()); res += dotProductBody256Int4Packed(unpacked, packed, i); } else if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_64.loopBound(packed.length); + i += ByteVector.SPECIES_64.loopBound(packed.byteSize()); res += dotProductBody128Int4Packed(unpacked, packed, i); } } // scalar tail - for (; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; + for (; i < packed.byteSize(); i++) { + byte packedByte = packed.get(JAVA_BYTE, i); + byte unpacked1 = unpacked.get(JAVA_BYTE, i); + byte unpacked2 = unpacked.get(JAVA_BYTE, i + packed.byteSize()); res += (packedByte & 0x0F) * unpacked2; res += ((packedByte & 0xFF) >> 4) * unpacked1; } } else { if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { return dotProduct(a, b); - } else if (a.length >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); + } else if (a.byteSize() >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.byteSize()); res += int4DotProductBody128(a, b, i); } // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + for (; i < a.byteSize(); i++) { + res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i); } } return res; } - private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody512Int4Packed( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short 
accumulator for (int i = 0; i < limit; i += 4096) { @@ -442,9 +449,12 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi int innerLimit = Math.min(limit - i, 4096); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + var vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j, LITTLE_ENDIAN); // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + var va8 = + ByteVector.fromMemorySegment( + ByteVector.SPECIES_256, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -452,7 +462,8 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector vc8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j, LITTLE_ENDIAN); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); acc1 = acc1.add(prod16a); @@ -466,7 +477,8 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody256Int4Packed( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 2048) { @@ -475,9 +487,12 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi int innerLimit = Math.min(limit - i, 2048); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + var vb8 = + 
ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j, LITTLE_ENDIAN); // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + var va8 = + ByteVector.fromMemorySegment( + ByteVector.SPECIES_128, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -485,7 +500,8 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector vc8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j, LITTLE_ENDIAN); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); acc1 = acc1.add(prod16a); @@ -500,7 +516,8 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi } /** vectorized dot product body (128 bit vectors) */ - private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody128Int4Packed( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { @@ -509,10 +526,12 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { // packed - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + ByteVector vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j, LITTLE_ENDIAN); // unpacked ByteVector va8 = - ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + ByteVector.fromMemorySegment( + ByteVector.SPECIES_64, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector 
prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -521,7 +540,7 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi acc0 = acc0.add(prod16.and((short) 0xFF)); // lower - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j, LITTLE_ENDIAN); prod8 = vb8.lanewise(LSHR, 4).mul(va8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); @@ -535,7 +554,7 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int int4DotProductBody128(byte[] a, byte[] b, int limit) { + private static int int4DotProductBody128(MemorySegment a, MemorySegment b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { @@ -543,15 +562,17 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) { ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j); + ByteVector va8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j, LITTLE_ENDIAN); + ByteVector vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j, LITTLE_ENDIAN); ByteVector prod8 = va8.mul(vb8); ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc0 = acc0.add(prod16.and((short) 0xFF)); - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8); - vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8, LITTLE_ENDIAN); + vb8 = 
ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8, LITTLE_ENDIAN); prod8 = va8.mul(vb8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 0e060586c2a6..d29a75888e9e 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -24,6 +24,7 @@ import java.util.logging.Logger; import jdk.incubator.vector.FloatVector; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.Constants; @@ -86,6 +87,12 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer( + new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer())); + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException { if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS