diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java index a1de3b4081441..89f6545056ad7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java @@ -82,4 +82,37 @@ Optional getInt7SQVectorScorerSupplier( * @return an optional containing the vector scorer, or empty */ Optional getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector); + + /** + * Returns an optional containing an int7 optimal scalar quantized vector score supplier + * for the given parameters, or an empty optional if a scorer is not supported. + * + * @param similarityType the similarity type + * @param input the index input containing the vector data + * @param values the random access vector values + * @return an optional containing the vector scorer supplier, or empty + */ + Optional getInt7uOSQVectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ); + + /** + * Returns an optional containing an int7 optimal scalar quantized vector scorer for + * the given parameters, or an empty optional if a scorer is not supported. 
+ * + * @param sim the similarity type + * @param values the random access vector values + * @return an optional containing the vector scorer, or empty + */ + Optional getInt7uOSQVectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 84227a8907b93..be449bf1ee376 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -60,4 +60,26 @@ public Optional getInt7SQVectorScorer( ) { throw new UnsupportedOperationException("should not reach here"); } + + @Override + public Optional getInt7uOSQVectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ) { + throw new UnsupportedOperationException("should not reach here"); + } + + @Override + public Optional getInt7uOSQVectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + throw new UnsupportedOperationException("should not reach here"); + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 2c1996cf1287c..af266d0b1d18b 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ 
b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -22,6 +22,8 @@ import org.elasticsearch.simdvec.internal.FloatVectorScorerSupplier; import org.elasticsearch.simdvec.internal.Int7SQVectorScorer; import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier; +import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorer; +import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorerSupplier; import java.util.Optional; @@ -90,6 +92,45 @@ public Optional getInt7SQVectorScorer( return Int7SQVectorScorer.create(sim, values, queryVector); } + @Override + public Optional getInt7uOSQVectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput msInput) { + checkInvariants(values.size(), values.dimension(), input); + return switch (similarityType) { + case COSINE, DOT_PRODUCT -> Optional.of(new Int7uOSQVectorScorerSupplier.DotProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new Int7uOSQVectorScorerSupplier.EuclideanSupplier(msInput, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new Int7uOSQVectorScorerSupplier.MaxInnerProductSupplier(msInput, values)); + }; + } + return Optional.empty(); + } + + @Override + public Optional getInt7uOSQVectorScorer( + VectorSimilarityFunction sim, + org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + return Int7uOSQVectorScorer.create( + sim, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ); + } + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { if (input.length() < (long) vectorByteLength * maxOrd) { throw new 
IllegalArgumentException("input length is less than expected vector data"); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java new file mode 100644 index 0000000000000..17cc0c808ae21 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.util.Optional; + +/** + * Outlines the Int7 OSQ query-time scorer. The concrete implementation will + * connect to the native OSQ routines and apply the similarity-specific + * corrections. 
+ */ +public final class Int7uOSQVectorScorer { + + public static Optional create( + VectorSimilarityFunction sim, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + // TODO add JDK21 fallback logic and native scorer dispatch + return Optional.empty(); + } + + private Int7uOSQVectorScorer() {} +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorerSupplier.java new file mode 100644 index 0000000000000..4ea8f378d45c1 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorerSupplier.java @@ -0,0 +1,334 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +import static org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.SUPPORTS_HEAP_SEGMENTS; +import static org.elasticsearch.simdvec.internal.Similarities.dotProduct7u; +import static org.elasticsearch.simdvec.internal.Similarities.dotProduct7uBulkWithOffsets; + +/** + * Int7 OSQ scorer supplier backed by {@link MemorySegmentAccessInput} storage. + */ +public abstract sealed class Int7uOSQVectorScorerSupplier implements RandomVectorScorerSupplier permits + Int7uOSQVectorScorerSupplier.DotProductSupplier, Int7uOSQVectorScorerSupplier.EuclideanSupplier, + Int7uOSQVectorScorerSupplier.MaxInnerProductSupplier { + + private static final float LIMIT_SCALE = 1f / ((1 << 7) - 1); + + protected final MemorySegmentAccessInput input; + protected final QuantizedByteVectorValues values; + protected final int dims; + protected final int maxOrd; + + Int7uOSQVectorScorerSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + this.input = input; + this.values = values; + this.dims = values.dimension(); + this.maxOrd = values.size(); + } + + protected abstract float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException; + + protected abstract float applyCorrectionsBulk(MemorySegment scores, MemorySegment ordinals, int numNodes, QueryContext query) + throws IOException; + + protected record QueryContext( + int ord, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) {} + + 
protected QueryContext createQueryContext(int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + return new QueryContext( + ord, + correctiveTerms.lowerInterval(), + correctiveTerms.upperInterval(), + correctiveTerms.additionalCorrection(), + correctiveTerms.quantizedComponentSum() + ); + } + + protected final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + protected final float scoreFromOrds(QueryContext query, int secondOrd) throws IOException { + int firstOrd = query.ord; + checkOrdinal(firstOrd); + checkOrdinal(secondOrd); + long vectorPitch = getVectorPitch(); + long firstVectorOffset = firstOrd * vectorPitch; + long secondVectorOffset = secondOrd * vectorPitch; + + MemorySegment first = input.segmentSliceOrNull(firstVectorOffset, dims); + MemorySegment second = input.segmentSliceOrNull(secondVectorOffset, dims); + if (first == null || second == null) { + return scoreViaFallback(query, secondOrd, firstVectorOffset, secondVectorOffset); + } + int rawScore = dotProduct7u(first, second, dims); + return applyCorrections(rawScore, secondOrd, query); + } + + protected final float bulkScoreFromOrds(QueryContext query, int[] ordinals, float[] scores, int numNodes) throws IOException { + checkOrdinal(query.ord); + MemorySegment vectors = input.segmentSliceOrNull(0, input.length()); + if (vectors == null) { + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + scores[i] = scoreFromOrds(query, ordinals[i]); + max = Math.max(max, scores[i]); + } + return max; + } + if (SUPPORTS_HEAP_SEGMENTS) { + var ordinalsSeg = MemorySegment.ofArray(ordinals); + var scoresSeg = MemorySegment.ofArray(scores); + computeBulkForQuery(query, vectors, ordinalsSeg, scoresSeg, numNodes); + return applyCorrectionsBulk(scoresSeg, ordinalsSeg, numNodes, query); + } else { + try (Arena arena = Arena.ofConfined()) { + MemorySegment ordinalsSeg = 
arena.allocate((long) numNodes * Integer.BYTES, Integer.BYTES); + MemorySegment scoresSeg = arena.allocate((long) numNodes * Float.BYTES, Float.BYTES); + MemorySegment.copy(ordinals, 0, ordinalsSeg, ValueLayout.JAVA_INT, 0, numNodes); + computeBulkForQuery(query, vectors, ordinalsSeg, scoresSeg, numNodes); + float max = applyCorrectionsBulk(scoresSeg, ordinalsSeg, numNodes, query); + MemorySegment.copy(scoresSeg, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes); + return max; + } + } + } + + private void computeBulkForQuery(QueryContext query, MemorySegment vectors, MemorySegment ordinals, MemorySegment scores, int numNodes) + throws IOException { + long firstByteOffset = query.ord * getVectorPitch(); + MemorySegment firstVector = vectors.asSlice(firstByteOffset, getVectorPitch()); + computeBulk(firstVector, vectors, ordinals, scores, numNodes); + } + + private float scoreViaFallback(QueryContext query, int secondOrd, long firstVectorOffset, long secondVectorOffset) throws IOException { + byte[] a = new byte[dims]; + byte[] b = new byte[dims]; + input.readBytes(firstVectorOffset, a, 0, dims); + input.readBytes(secondVectorOffset, b, 0, dims); + // Just fall back to regular dot-product and apply corrections + int raw = VectorUtil.dotProduct(a, b); + return applyCorrections(raw, secondOrd, query); + } + + protected final void computeBulk( + MemorySegment firstVector, + MemorySegment vectors, + MemorySegment ordinals, + MemorySegment scores, + int numNodes + ) throws IOException { + dotProduct7uBulkWithOffsets(vectors, firstVector, dims, (int) getVectorPitch(), ordinals, numNodes, scores); + } + + protected final long getVectorPitch() { + return dims + 3L * Float.BYTES + Integer.BYTES; + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int ord = -1; + private QueryContext query; + + @Override + public float score(int node) throws IOException { + if (query 
== null) { + throw new IllegalStateException("scoring ordinal is not set"); + } + return scoreFromOrds(query, node); + } + + @Override + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + if (query == null) { + throw new IllegalStateException("scoring ordinal is not set"); + } + return bulkScoreFromOrds(query, nodes, scores, numNodes); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + checkOrdinal(node); + ord = node; + query = createQueryContext(node); + } + }; + } + + public QuantizedByteVectorValues get() { + return values; + } + + public static final class DotProductSupplier extends Int7uOSQVectorScorerSupplier { + public DotProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new DotProductSupplier(input.clone(), values.copy()); + } + + @Override + protected float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += query.additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.normalizeToUnitInterval(Math.clamp(score, -1, 1)); + } + + @Override + protected float applyCorrectionsBulk(MemorySegment scoreSeg, MemorySegment ordinalsSeg, int numNodes, QueryContext query) + throws IOException { + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float 
max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int ord = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float adjustedScore = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * raw; + adjustedScore += query.additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + float normalized = VectorUtil.normalizeToUnitInterval(Math.clamp(adjustedScore, -1, 1)); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); + } + return max; + } + + } + + public static final class EuclideanSupplier extends Int7uOSQVectorScorerSupplier { + public EuclideanSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new EuclideanSupplier(input.clone(), values.copy()); + } + + @Override + protected float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score = query.additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + + @Override + protected float applyCorrectionsBulk(MemorySegment scoreSeg, MemorySegment 
ordinalsSeg, int numNodes, QueryContext query) + throws IOException { + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int ord = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * raw; + score = query.additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + float normalized = VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalized); + max = Math.max(max, normalized); + } + return max; + } + } + + public static final class MaxInnerProductSupplier extends Int7uOSQVectorScorerSupplier { + public MaxInnerProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new MaxInnerProductSupplier(input.clone(), values.copy()); + } + + @Override + protected float applyCorrections(float rawScore, int ord, QueryContext query) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += query.additionalCorrection + 
correctiveTerms.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.scaleMaxInnerProductScore(score); + } + + @Override + protected float applyCorrectionsBulk(MemorySegment scoreSeg, MemorySegment ordinalsSeg, int numNodes, QueryContext query) + throws IOException { + float ay = query.lowerInterval; + float ly = (query.upperInterval - ay) * LIMIT_SCALE; + float y1 = query.quantizedComponentSum; + float max = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + float raw = scoreSeg.getAtIndex(ValueLayout.JAVA_FLOAT, i); + int ord = ordinalsSeg.getAtIndex(ValueLayout.JAVA_INT, i); + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * raw; + score += query.additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + float normalizedScore = VectorUtil.scaleMaxInnerProductScore(score); + scoreSeg.setAtIndex(ValueLayout.JAVA_FLOAT, i, normalizedScore); + max = Math.max(max, normalizedScore); + } + return max; + } + } +} diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java new file mode 100644 index 0000000000000..320d58f0fdbad --- /dev/null +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7uOSQVectorScorer.java @@ -0,0 +1,332 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +import static org.elasticsearch.simdvec.internal.Similarities.dotProduct7u; +import static org.elasticsearch.simdvec.internal.Similarities.dotProduct7uBulkWithOffsets; + +/** + * JDK-22+ implementation for Int7 OSQ query-time scorers. 
+ */ +public abstract sealed class Int7uOSQVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer permits + Int7uOSQVectorScorer.DotProductScorer, Int7uOSQVectorScorer.EuclideanScorer, Int7uOSQVectorScorer.MaxInnerProductScorer { + + private static final float LIMIT_SCALE = 1f / ((1 << 7) - 1); + + public static Optional create( + VectorSimilarityFunction sim, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + if (quantizedQuery.length != values.getVectorByteLength()) { + throw new IllegalArgumentException( + "quantized query length " + quantizedQuery.length + " differs from vector byte length " + values.getVectorByteLength() + ); + } + + var input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if ((input instanceof MemorySegmentAccessInput) == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.getVectorByteLength(), input); + + return switch (sim) { + case COSINE, DOT_PRODUCT -> Optional.of( + new DotProductScorer( + msInput, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + case EUCLIDEAN -> Optional.of( + new EuclideanScorer( + msInput, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + case MAXIMUM_INNER_PRODUCT -> Optional.of( + new MaxInnerProductScorer( + msInput, + values, + quantizedQuery, + lowerInterval, + upperInterval, + additionalCorrection, + quantizedComponentSum + ) + ); + }; + } + + final QuantizedByteVectorValues values; + final MemorySegmentAccessInput input; + final int vectorByteSize; + final MemorySegment query; + final float lowerInterval; + final float upperInterval; + final float 
additionalCorrection; + final int quantizedComponentSum; + byte[] scratch; + + Int7uOSQVectorScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(values); + this.values = values; + this.input = input; + this.vectorByteSize = values.getVectorByteLength(); + this.query = MemorySegment.ofArray(quantizedQuery); + this.lowerInterval = lowerInterval; + this.upperInterval = upperInterval; + this.additionalCorrection = additionalCorrection; + this.quantizedComponentSum = quantizedComponentSum; + } + + abstract float applyCorrections(float rawScore, int ord) throws IOException; + + abstract float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException; + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + int dotProduct = dotProduct7u(query, getSegment(node), vectorByteSize); + return applyCorrections(dotProduct, node); + } + + @Override + public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length()); + if (vectorsSeg == null) { + return super.bulkScore(nodes, scores, numNodes); + } else { + var ordinalsSeg = MemorySegment.ofArray(nodes); + var scoresSeg = MemorySegment.ofArray(scores); + + var vectorPitch = vectorByteSize + 3 * Float.BYTES + Integer.BYTES; + dotProduct7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg); + return applyCorrectionsBulk(scores, nodes, numNodes); + } + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * (vectorByteSize + 3 * Float.BYTES + Integer.BYTES); + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new 
byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + public static final class DotProductScorer extends Int7uOSQVectorScorer { + public DotProductScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(input, values, quantizedQuery, lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum); + } + + @Override + float applyCorrections(float rawScore, int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float x1 = correctiveTerms.quantizedComponentSum(); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + score = Math.clamp(score, -1, 1); + return VectorUtil.normalizeToUnitInterval(score); + } + + @Override + float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException { + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + int ord = ords[i]; + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * 
ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * scores[i]; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + score = Math.clamp(score, -1, 1); + scores[i] = VectorUtil.normalizeToUnitInterval(score); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } + } + + public static final class EuclideanScorer extends Int7uOSQVectorScorer { + public EuclideanScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(input, values, quantizedQuery, lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum); + } + + @Override + float applyCorrections(float rawScore, int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float x1 = correctiveTerms.quantizedComponentSum(); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score = additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + + @Override + float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException { + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + int ord = ords[i]; + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float x1 = 
correctiveTerms.quantizedComponentSum(); + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * scores[i]; + score = additionalCorrection + correctiveTerms.additionalCorrection() - 2 * score; + scores[i] = VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } + } + + public static final class MaxInnerProductScorer extends Int7uOSQVectorScorer { + public MaxInnerProductScorer( + MemorySegmentAccessInput input, + QuantizedByteVectorValues values, + byte[] quantizedQuery, + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum + ) { + super(input, values, quantizedQuery, lowerInterval, upperInterval, additionalCorrection, quantizedComponentSum); + } + + @Override + float applyCorrections(float rawScore, int ord) throws IOException { + var correctiveTerms = values.getCorrectiveTerms(ord); + float x1 = correctiveTerms.quantizedComponentSum(); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * rawScore; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + return VectorUtil.scaleMaxInnerProductScore(score); + } + + @Override + float applyCorrectionsBulk(float[] scores, int[] ords, int numNodes) throws IOException { + float ay = lowerInterval; + float ly = (upperInterval - ay) * LIMIT_SCALE; + float y1 = quantizedComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < numNodes; i++) { + int ord = ords[i]; + var correctiveTerms = values.getCorrectiveTerms(ord); + float ax = correctiveTerms.lowerInterval(); + float lx = (correctiveTerms.upperInterval() - ax) * LIMIT_SCALE; + 
float x1 = correctiveTerms.quantizedComponentSum(); + float score = ax * ay * values.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * scores[i]; + score += additionalCorrection + correctiveTerms.additionalCorrection() - values.getCentroidDP(); + scores[i] = VectorUtil.scaleMaxInnerProductScore(score); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } + } + return maxScore; + } + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + long vectorPitch = vectorByteLength + 3L * Float.BYTES + Integer.BYTES; + if (input.length() < vectorPitch * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java index 3905c462d29ca..259e5bafb765e 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java @@ -23,11 +23,11 @@ public abstract class AbstractVectorTestCase extends ESTestCase { - static Optional factory; + static Optional factory; @BeforeClass public static void getVectorScorerFactory() { - factory = VectorScorerFactory.instance(); + factory = org.elasticsearch.simdvec.VectorScorerFactory.instance(); } protected AbstractVectorTestCase() { diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java new file mode 100644 index 0000000000000..86991e67b1f89 --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7uOSQVectorScorerFactoryTests.java @@ -0,0 +1,894 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.core.SuppressForbidden; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; +import static 
org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.hamcrest.Matchers.equalTo; + +public class Int7uOSQVectorScorerFactoryTests extends org.elasticsearch.simdvec.AbstractVectorTestCase { + private static final float LIMIT_SCALE = 1f / ((1 << 7) - 1); + + @SuppressForbidden(reason = "require usage of OptimizedScalarQuantizer") + private static OptimizedScalarQuantizer scalarQuantizer(VectorSimilarityFunction sim) { + return new OptimizedScalarQuantizer(sim); + } + + // bounds of the range of values that can be seen by int7 scalar quantized vectors + static final byte MIN_INT7_VALUE = 0; + static final byte MAX_INT7_VALUE = 127; + + // Tests that the provider instance is present or not on expected platforms/architectures + public void testSupport() { + supported(); + } + + public void testSimple() throws IOException { + testSimpleImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testSimpleMaxChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(4, 16); + logger.info("maxChunkSize=" + maxChunkSize); + testSimpleImpl(maxChunkSize); + } + + void testSimpleImpl(long maxChunkSize) throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) { + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var scalarQuantizer = scalarQuantizer(sim.function()); + for (int dims : List.of(31, 32, 33)) { + // dimensions that cross the scalar / native boundary (stride) + byte[] vec1 = new byte[dims]; + byte[] vec2 = new byte[dims]; + float[] query1 = new float[dims]; + float[] query2 = new float[dims]; + float[] centroid = new float[dims]; + float centroidDP = 0f; + OptimizedScalarQuantizer.QuantizationResult vec1Correction, vec2Correction; + 
String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + query1[i] = (float) i; + query2[i] = (float) (dims - i); + centroid[i] = (query1[i] + query2[i]) / 2f; + centroidDP += centroid[i] * centroid[i]; + } + vec1Correction = scalarQuantizer.scalarQuantize(query1, vec1, (byte) 7, centroid); + vec2Correction = scalarQuantizer.scalarQuantize(query2, vec2, (byte) 7, centroid); + out.writeBytes(vec1, 0, vec1.length); + out.writeInt(Float.floatToIntBits(vec1Correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(vec1Correction.upperInterval())); + out.writeInt(Float.floatToIntBits(vec1Correction.additionalCorrection())); + out.writeInt(vec1Correction.quantizedComponentSum()); + out.writeBytes(vec2, 0, vec2.length); + out.writeInt(Float.floatToIntBits(vec2Correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(vec2Correction.upperInterval())); + out.writeInt(Float.floatToIntBits(vec2Correction.additionalCorrection())); + out.writeInt(vec2Correction.quantizedComponentSum()); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 2, centroid, centroidDP, in, sim.function()); + float expected = luceneScore(sim, vec1, vec2, centroidDP, vec1Correction, vec2Correction); + + var luceneSupplier = luceneScoreSupplier(values, sim.function()).scorer(); + luceneSupplier.setScoringOrdinal(1); + assertFloatEquals(expected, luceneSupplier.score(0), 1e-6f); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(1); + assertFloatEquals(expected, scorer.score(0), 1e-6f); + + if (supportsHeapSegments()) { + var qScorer = factory.getInt7uOSQVectorScorer( + sim.function(), + values, + vec2, + vec2Correction.lowerInterval(), + vec2Correction.upperInterval(), + vec2Correction.additionalCorrection(), + 
vec2Correction.quantizedComponentSum() + ).get(); + assertFloatEquals(expected, qScorer.score(0), 1e-6f); + } + } + } + } + } + } + + public void testRandom() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC); + } + + public void testRandomMaxChunkSizeSmall() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomSupplier(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC); + } + + public void testRandomMax() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC); + } + + public void testRandomMin() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC); + } + + void testRandomSupplier(long maxChunkSize, Function byteArraySupplier) throws IOException { + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] quantizationResults = new OptimizedScalarQuantizer.QuantizationResult[size]; + final float[] centroid = new float[dims]; + + String fileName = "testRandom-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = byteArraySupplier.apply(dims); + int componentSum = 0; + for (int d = 0; d < dims; d++) { + componentSum += Byte.toUnsignedInt(vec[d]); + } + float lowerInterval = randomFloat(); + float upperInterval = randomFloat() + lowerInterval; + 
quantizationResults[i] = new OptimizedScalarQuantizer.QuantizationResult( + lowerInterval, + upperInterval, + randomFloat(), + componentSum + ); + out.writeBytes(vec, 0, vec.length); + out.writeInt(Float.floatToIntBits(lowerInterval)); + out.writeInt(Float.floatToIntBits(upperInterval)); + out.writeInt(Float.floatToIntBits(quantizationResults[i].additionalCorrection())); + out.writeInt(componentSum); + vectors[i] = vec; + } + } + for (int i = 0; i < dims; i++) { + centroid[i] = randomFloat(); + } + float centroidDP = VectorUtil.dotProduct(centroid, centroid); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = luceneScore( + sim, + vectors[idx0], + vectors[idx1], + centroidDP, + quantizationResults[idx0], + quantizationResults[idx1] + ); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + + public void testRandomScorer() throws IOException { + testRandomScorerImpl( + MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, + org.elasticsearch.simdvec.Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC + ); + } + + public void testRandomScorerMax() throws IOException { + testRandomScorerImpl( + MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, + org.elasticsearch.simdvec.Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC + ); + } + + public void testRandomScorerChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomScorerImpl(maxChunkSize, 
FLOAT_ARRAY_RANDOM_FUNC); + } + + void testRandomScorerImpl(long maxChunkSize, Function floatArraySupplier) throws IOException { + assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var scalarQuantizer = new OptimizedScalarQuantizer(sim.function()); + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final float[] centroid = new float[dims]; + for (int i = 0; i < dims; i++) { + centroid[i] = randomFloat(); + } + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final float[][] vectors = new float[size][]; + final byte[][] qVectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testRandom-" + sim + "-" + dims + ".vex"; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + vectors[i] = floatArraySupplier.apply(dims); + qVectors[i] = new byte[dims]; + corrections[i] = scalarQuantizer.scalarQuantize(vectors[i], qVectors[i], (byte) 7, centroid); + out.writeBytes(qVectors[i], 0, qVectors[i].length); + out.writeInt(Float.floatToIntBits(corrections[i].lowerInterval())); + out.writeInt(Float.floatToIntBits(corrections[i].upperInterval())); + out.writeInt(Float.floatToIntBits(corrections[i].additionalCorrection())); + out.writeInt(corrections[i].quantizedComponentSum()); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); + var 
values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + + var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], centroidDP, corrections[idx0], corrections[idx1]); + var scorer = factory.getInt7uOSQVectorScorer( + sim.function(), + values, + qVectors[idx0], + corrections[idx0].lowerInterval(), + corrections[idx0].upperInterval(), + corrections[idx0].additionalCorrection(), + corrections[idx0].quantizedComponentSum() + ).get(); + assertFloatEquals(expected, scorer.score(idx1), 1e-6f); + } + } + } + } + } + + public void testRandomSlice() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_INT7_FUNC); + } + + void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, Function byteArraySupplier) + throws IOException { + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) { + for (int times = 0; times < TIMES; times++) { + final int size = randomIntBetween(2, 100); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testRandomSliceImpl-" + times + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] ba = new byte[initialPadding]; + out.writeBytes(ba, 0, ba.length); + for (int i = 0; i < size; i++) { + var vec = byteArraySupplier.apply(dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + try ( + var outter = dir.openInput(fileName, IOContext.DEFAULT); + var in = outter.slice("slice", 
initialPadding, outter.length() - initialPadding) + ) { + for (int itrs = 0; itrs < TIMES / 10; itrs++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = luceneScore( + sim, + vectors[idx0], + vectors[idx1], + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + } + + // Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow + @Nightly + public void testLarge() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) { + final int dims = 8192; + final int size = 262144; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testLarge-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + corrections[i] = correction; + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + for (var sim : List.of(DOT_PRODUCT, 
EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = luceneScore( + sim, + vector(idx0, dims), + vector(idx1, dims), + centroidDP, + corrections[idx0], + corrections[idx1] + ); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + } + + // Test that the scorer works well when the IndexInput is greater than the directory segment chunk size + public void testDatasetGreaterThanChunkSize() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testDatasetGreaterThanChunkSize"), 8192)) { + final int dims = 1024; + final int size = 128; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + + String fileName = "testDatasetGreaterThanChunkSize-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float expected = 
luceneScore(sim, vectors[idx0], vectors[idx1], centroidDP, corrections[idx0], corrections[idx1]); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx1); + assertFloatEquals(expected, scorer.score(idx0), 1e-6f); + } + } + } + } + + public void testBulk() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = randomIntBetween(1, 102); + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + // Uses the default maxChunkSize; the small-chunk (dataset > chunk) case is covered by testBulkWithDatasetGreaterThanChunkSize + try (Directory dir = new MMapDirectory(createTempDir("testBulk"))) { + String fileName = "testBulk-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + for (var sim : List.of(EUCLIDEAN)) { + QuantizedByteVectorValues values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, sim.function()).scorer(); + 
referenceScorer.setScoringOrdinal(idx0); + referenceScorer.bulkScore(nodes, expected, nodes.length); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).orElseThrow(); + var testScorer = supplier.scorer(); + testScorer.setScoringOrdinal(idx0); + testScorer.bulkScore(nodes, scores, nodes.length); + // applying the corrections in even a slightly different order can impact the score + // account for this during bulk scoring + assertFloatArrayEquals(expected, scores, 2e-5f); + } + } + } + } + } + + public void testBulkWithDatasetGreaterThanChunkSize() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + final int dims = 1024; + final int size = 128; + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + final byte[][] vectors = new byte[size][]; + final OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[size]; + // Set maxChunkSize to be less than dims * size + try (Directory dir = new MMapDirectory(createTempDir("testBulkWithDatasetGreaterThanChunkSize"), 8192)) { + String fileName = "testBulkWithDatasetGreaterThanChunkSize-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + var correction = randomCorrection(vec); + writeVectorWithCorrection(out, vec, correction); + vectors[i] = vec; + corrections[i] = correction; + } + } + + List ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); + for (var sim : List.of(DOT_PRODUCT, EUCLIDEAN, 
MAXIMUM_INNER_PRODUCT)) { + QuantizedByteVectorValues values = vectorValues(dims, size, centroid, centroidDP, in, sim.function()); + float[] expected = new float[nodes.length]; + float[] scores = new float[nodes.length]; + var referenceScorer = luceneScoreSupplier(values, sim.function()).scorer(); + referenceScorer.setScoringOrdinal(idx0); + referenceScorer.bulkScore(nodes, expected, nodes.length); + var supplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).orElseThrow(); + var testScorer = supplier.scorer(); + testScorer.setScoringOrdinal(idx0); + testScorer.bulkScore(nodes, scores, nodes.length); + assertFloatArrayEquals(expected, scores, 1e-6f); + } + } + } + } + } + + public void testRace() throws Exception { + testRaceImpl(DOT_PRODUCT); + testRaceImpl(EUCLIDEAN); + testRaceImpl(MAXIMUM_INNER_PRODUCT); + } + + // Tests that copies in threads do not interfere with each other + void testRaceImpl(org.elasticsearch.simdvec.VectorSimilarityType sim) throws Exception { + assumeTrue(notSupportedMsg(), supported()); + var factory = org.elasticsearch.simdvec.AbstractVectorTestCase.factory.get(); + + final long maxChunkSize = 32; + final int dims = 34; // dimensions that are larger than the chunk size, to force fallback + final float[] centroid = FLOAT_ARRAY_RANDOM_FUNC.apply(dims); + final float centroidDP = VectorUtil.dotProduct(centroid, centroid); + byte[] vec1 = new byte[dims]; + byte[] vec2 = new byte[dims]; + IntStream.range(0, dims).forEach(i -> vec1[i] = 1); + IntStream.range(0, dims).forEach(i -> vec2[i] = 2); + var correction1 = randomCorrection(vec1); + var correction2 = randomCorrection(vec2); + try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { + String fileName = "testRace-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writeVectorWithCorrection(out, vec1, correction1); + writeVectorWithCorrection(out, vec1, correction1); + writeVectorWithCorrection(out, vec2, correction2); 
+ writeVectorWithCorrection(out, vec2, correction2); + } + var expectedScore1 = luceneScore(sim, vec1, vec1, centroidDP, correction1, correction1); + var expectedScore2 = luceneScore(sim, vec2, vec2, centroidDP, correction2, correction2); + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 4, centroid, centroidDP, in, sim.function()); + var scoreSupplier = factory.getInt7uOSQVectorScorerSupplier(sim, in, values).get(); + var tasks = List.>>of( + new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1), + new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2) + ); + var executor = Executors.newFixedThreadPool(2); + var results = executor.invokeAll(tasks); + executor.shutdown(); + assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS)); + assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L)); + for (var res : results) { + assertThat("Unexpected exception" + res.get(), res.get(), isEmpty()); + } + } + } + } + + static class ScoreCallable implements Callable> { + + final UpdateableRandomVectorScorer scorer; + final int ord; + final float expectedScore; + + ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) { + try { + this.scorer = scorer; + this.scorer.setScoringOrdinal(queryOrd); + this.ord = ord; + this.expectedScore = expectedScore; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Optional call() { + try { + for (int i = 0; i < 100; i++) { + assertFloatEquals(expectedScore, scorer.score(ord), 1e-6f); + } + } catch (Throwable t) { + return Optional.of(t); + } + return Optional.empty(); + } + } + + private static OptimizedScalarQuantizer.QuantizationResult randomCorrection(byte[] vec) { + int componentSum = 0; + for (byte value : vec) { + componentSum += Byte.toUnsignedInt(value); + } + float lowerInterval = randomFloat(); + float upperInterval = lowerInterval + 
randomFloat(); + return new OptimizedScalarQuantizer.QuantizationResult(lowerInterval, upperInterval, randomFloat(), componentSum); + } + + private static void writeVectorWithCorrection(IndexOutput out, byte[] vec, OptimizedScalarQuantizer.QuantizationResult correction) + throws IOException { + out.writeBytes(vec, 0, vec.length); + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + out.writeInt(correction.quantizedComponentSum()); + } + + QuantizedByteVectorValues vectorValues( + int dims, + int size, + float[] centroid, + float centroidDP, + IndexInput in, + VectorSimilarityFunction sim + ) throws IOException { + var slice = in.slice("values", 0, in.length()); + return new DenseOffHeapScalarQuantizedVectorValues(dims, size, sim, slice, centroid, centroidDP); + } + + /** Computes the score using the Lucene implementation. */ + public float luceneScore( + org.elasticsearch.simdvec.VectorSimilarityType similarityFunc, + byte[] a, + byte[] b, + float centroidDP, + OptimizedScalarQuantizer.QuantizationResult aCorrection, + OptimizedScalarQuantizer.QuantizationResult bCorrection + ) { + OSQScorer scorer = OSQScorer.fromSimilarity(similarityFunc); + return scorer.score(a, b, centroidDP, aCorrection, bCorrection); + } + + private abstract static class OSQScorer { + static OSQScorer fromSimilarity(org.elasticsearch.simdvec.VectorSimilarityType sim) { + return switch (sim) { + case DOT_PRODUCT -> new DotProductOSQScorer(); + case MAXIMUM_INNER_PRODUCT -> new MaxInnerProductOSQScorer(); + case EUCLIDEAN -> new EuclideanOSQScorer(); + default -> throw new IllegalArgumentException("Unsupported similarity: " + sim); + }; + } + + final float score( + byte[] a, + byte[] b, + float centroidDP, + OptimizedScalarQuantizer.QuantizationResult aCorrection, + OptimizedScalarQuantizer.QuantizationResult bCorrection + ) { + float ax = 
aCorrection.lowerInterval(); + float lx = (aCorrection.upperInterval() - ax) * LIMIT_SCALE; + float ay = bCorrection.lowerInterval(); + float ly = (bCorrection.upperInterval() - ay) * LIMIT_SCALE; + float y1 = bCorrection.quantizedComponentSum(); + float x1 = aCorrection.quantizedComponentSum(); + float score = ax * ay * a.length + ay * lx * x1 + ax * ly * y1 + lx * ly * VectorUtil.dotProduct(a, b); + return scaleScore(score, aCorrection.additionalCorrection(), bCorrection.additionalCorrection(), centroidDP); + } + + abstract float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP); + + private static class DotProductOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score += aCorrection + bCorrection - centroidDP; + score = Math.clamp(score, -1, 1); + return VectorUtil.normalizeToUnitInterval(score); + } + } + + private static class MaxInnerProductOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score += aCorrection + bCorrection - centroidDP; + return VectorUtil.scaleMaxInnerProductScore(score); + } + } + + private static class EuclideanOSQScorer extends OSQScorer { + @Override + float scaleScore(float score, float aCorrection, float bCorrection, float centroidDP) { + score = aCorrection + bCorrection - 2 * score; + return VectorUtil.normalizeDistanceToUnitInterval(Math.max(score, 0f)); + } + } + } + + static void assertFloatArrayEquals(float[] expected, float[] actual, float delta) { + assertThat(actual.length, equalTo(expected.length)); + for (int i = 0; i < expected.length; i++) { + assertEquals("differed at element [" + i + "]", expected[i], actual[i], Math.abs(expected[i]) * delta + delta); + } + } + + static void assertFloatEquals(float expected, float actual, float delta) { + assertEquals(expected, actual, Math.abs(expected) * delta + delta); + } + + static 
RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) + throws IOException { + return new Lucene104ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); + } + + // creates the vector based on the given ordinal, which is reproducible given the ord and dims + static byte[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + byte[] ba = new byte[dims]; + for (int i = 0; i < dims; i++) { + ba[i] = (byte) RandomNumbers.randomIntBetween(random, MIN_INT7_VALUE, MAX_INT7_VALUE); + } + return ba; + } + + static Function FLOAT_ARRAY_RANDOM_FUNC = size -> { + float[] fa = new float[size]; + for (int i = 0; i < size; i++) { + fa[i] = randomFloat(); + } + return fa; + }; + + static Function BYTE_ARRAY_RANDOM_INT7_FUNC = size -> { + byte[] ba = new byte[size]; + randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE); + return ba; + }; + + static Function BYTE_ARRAY_MAX_INT7_FUNC = size -> { + byte[] ba = new byte[size]; + Arrays.fill(ba, MAX_INT7_VALUE); + return ba; + }; + + static Function BYTE_ARRAY_MIN_INT7_FUNC = size -> { + byte[] ba = new byte[size]; + Arrays.fill(ba, MIN_INT7_VALUE); + return ba; + }; + + static final int TIMES = 100; // number of iterations for the randomized test loops + + static class DenseOffHeapScalarQuantizedVectorValues extends QuantizedByteVectorValues { + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + + final IndexInput slice; + final byte[] vectorValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final float[] centroid; + final float centroidDp; + + DenseOffHeapScalarQuantizedVectorValues( + int dimension, + int size, + VectorSimilarityFunction similarityFunction, + IndexInput slice, + float[] centroid, + float centroidDp + ) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = 
similarityFunction; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.correctiveValues = new float[3]; + this.byteSize = dimension + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(dimension); + this.vectorValue = byteBuffer.array(); + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException { + if (lastOrd != vectorOrd) { + slice.seek((long) vectorOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = vectorOrd; + } + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + return scalarQuantizer(similarityFunction); + } + + @Override + public Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding getScalarEncoding() { + return Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT; + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public float getCentroidDP() throws IOException { + return centroidDp; + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + assert false; + return null; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (lastOrd == ord) { + return vectorValue; + } + slice.seek((long) ord * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = ord; + return vectorValue; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; 
+ } + + @Override + public QuantizedByteVectorValues copy() throws IOException { + return new DenseOffHeapScalarQuantizedVectorValues(dimension, size, similarityFunction, slice.clone(), centroid, centroidDp); + } + } +} diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 3b562b0dda90d..41f18f90a0f90 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -39,9 +39,9 @@ import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.logging.Level; import org.elasticsearch.logging.LogManager; @@ -189,15 +189,13 @@ static Codec createCodec(TestConfiguration args, @Nullable ExecutorService exec) } } else if (quantizeBits < 32) { if (args.indexType() == IndexType.FLAT) { - format = new ES93ScalarQuantizedVectorsFormat(elementType, null, quantizeBits, true, false); + format = new ES94ScalarQuantizedVectorsFormat(elementType, quantizeBits, false); } else { - format = new ES93HnswScalarQuantizedVectorsFormat( + format = new ES94HnswScalarQuantizedVectorsFormat( args.hnswM(), args.hnswEfConstruction(), elementType, - null, quantizeBits, - true, false, exec != 
null ? args.numMergeWorkers() : 1, exec diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 7754fc1e982f4..600e53a54a01b 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -475,7 +475,9 @@ org.elasticsearch.index.codec.vectors.es93.ES93ScalarQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswScalarQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec with @@ -510,6 +512,7 @@ exports org.elasticsearch.index.codec.vectors.diskbbq.next to org.elasticsearch.test.knn, org.elasticsearch.xpack.diskbbq; exports org.elasticsearch.index.codec.vectors.cluster to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es93 to org.elasticsearch.test.knn; + exports org.elasticsearch.index.codec.vectors.es94 to org.elasticsearch.test.knn; exports org.elasticsearch.search.crossproject; exports org.elasticsearch.index.mapper.blockloader; exports org.elasticsearch.index.mapper.blockloader.docvalues; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..96615e922419d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedVectorsFormat.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es94;

import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

/**
 * An HNSW vectors format whose flat (non-graph) storage is provided by
 * {@link ES94ScalarQuantizedVectorsFormat}, i.e. optimized-scalar-quantized vectors
 * kept alongside the raw vectors. Graph construction and search are delegated to
 * Lucene's 99 HNSW writer/reader.
 */
public class ES94HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFormat {

    static final String NAME = "ES94HnswScalarQuantizedVectorsFormat";

    /** Flat storage used underneath the HNSW graph. */
    private final FlatVectorsFormat quantizedFlatFormat;

    /** Default configuration: float elements, 7-bit quantization, no direct IO. */
    public ES94HnswScalarQuantizedVectorsFormat() {
        super(NAME);
        this.quantizedFlatFormat = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT);
    }

    /**
     * @param maxConn maximum number of graph connections per node
     * @param beamWidth beam width used while building the graph
     * @param elementType raw vector element type
     * @param bits quantization width (1, 2, 4 or 7)
     * @param useDirectIO whether raw vectors are read with direct IO
     */
    public ES94HnswScalarQuantizedVectorsFormat(
        int maxConn,
        int beamWidth,
        DenseVectorFieldMapper.ElementType elementType,
        int bits,
        boolean useDirectIO
    ) {
        super(NAME, maxConn, beamWidth);
        this.quantizedFlatFormat = new ES94ScalarQuantizedVectorsFormat(elementType, bits, useDirectIO);
    }

    /**
     * Same as the five-argument constructor, but merges may be parallelized across
     * {@code numMergeWorkers} tasks on {@code mergeExec}.
     */
    public ES94HnswScalarQuantizedVectorsFormat(
        int maxConn,
        int beamWidth,
        DenseVectorFieldMapper.ElementType elementType,
        int bits,
        boolean useDirectIO,
        int numMergeWorkers,
        ExecutorService mergeExec
    ) {
        super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
        this.quantizedFlatFormat = new ES94ScalarQuantizedVectorsFormat(elementType, bits, useDirectIO);
    }

    @Override
    protected FlatVectorsFormat flatVectorsFormat() {
        return quantizedFlatFormat;
    }

    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        var flatWriter = quantizedFlatFormat.fieldsWriter(state);
        // NOTE(review): the trailing 0 reproduces the original configuration — confirm its
        // meaning against the Lucene99HnswVectorsWriter constructor before changing it.
        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatWriter, numMergeWorkers, mergeExec, 0);
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new Lucene99HnswVectorsReader(state, quantizedFlatFormat.fieldsReader(state));
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es94;

import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsReader;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsWriter;
import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.es93.ES93FlatVectorScorer;
import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.simdvec.VectorScorerFactory;
import org.elasticsearch.simdvec.VectorSimilarityType;

import java.io.IOException;

/**
 * A flat vectors format that stores an optimized-scalar-quantized (OSQ) copy of each
 * vector next to the raw vectors, using Lucene's lucene104 quantized writer/reader.
 * Supported quantization widths are 1, 2, 4 and 7 bits.
 */
public class ES94ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

    static final String NAME = "ES94ScalarQuantizedVectorsFormat";

    // Shared scorer: prefers the native int7 OSQ scorers when the platform provides them,
    // otherwise falls back to Lucene's implementation via the ES93 delegate.
    static final Lucene104ScalarQuantizedVectorScorer flatVectorScorer = new ESQuantizedFlatVectorsScorer(ES93FlatVectorScorer.INSTANCE);

    private final FlatVectorsFormat rawVectorFormat;
    private final Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encoding;

    /** Float elements, 7-bit quantization, no direct IO. */
    public ES94ScalarQuantizedVectorsFormat() {
        this(DenseVectorFieldMapper.ElementType.FLOAT, 7, false);
    }

    /** 7-bit quantization over the given element type, no direct IO. */
    public ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType) {
        this(elementType, 7, false);
    }

    /**
     * @param elementType raw vector element type; must not be {@code BIT}
     * @param bits quantization width; one of 1, 2, 4 or 7
     * @param useDirectIO whether the raw vectors are read with direct IO
     * @throws IllegalArgumentException if {@code bits} is not a supported width
     */
    public ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, int bits, boolean useDirectIO) {
        super(NAME);
        this.encoding = encodingFor(bits);
        assert elementType != DenseVectorFieldMapper.ElementType.BIT : "BIT should not be used with scalar quantization";
        this.rawVectorFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
    }

    /** Maps a supported bit width onto its wire encoding, rejecting everything else. */
    private static Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encodingFor(int bits) {
        return switch (bits) {
            case 1 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE;
            case 2 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.DIBIT_QUERY_NIBBLE;
            case 4 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE;
            case 7 -> Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT;
            default -> throw new IllegalArgumentException("bits must be one of: 1, 2, 4, 7; bits=" + bits);
        };
    }

    @Override
    public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new Lucene104ScalarQuantizedVectorsWriter(state, encoding, rawVectorFormat.fieldsWriter(state), flatVectorScorer);
    }

    @Override
    public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new Lucene104ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), flatVectorScorer);
    }

    @Override
    public String toString() {
        return NAME
            + "(name="
            + NAME
            + ", encoding="
            + encoding
            + ", flatVectorScorer="
            + flatVectorScorer
            + ", rawVectorFormat="
            + rawVectorFormat
            + ")";
    }

    /**
     * Scorer that routes seven-bit, off-heap quantized vectors to the native
     * {@link VectorScorerFactory} scorers when available, and otherwise defers to the
     * Lucene scalar-quantized scorer.
     */
    static final class ESQuantizedFlatVectorsScorer extends Lucene104ScalarQuantizedVectorScorer {

        final FlatVectorsScorer delegate;
        // Null when no native (SIMD) implementation is available on this platform.
        final VectorScorerFactory factory;

        ESQuantizedFlatVectorsScorer(FlatVectorsScorer delegate) {
            super(delegate);
            this.delegate = delegate;
            this.factory = VectorScorerFactory.instance().orElse(null);
        }

        @Override
        public String toString() {
            return "ESQuantizedFlatVectorsScorer(delegate=" + delegate + ", factory=" + factory + ")";
        }

        /**
         * Returns {@code values} as seven-bit quantized vectors backed by an off-heap slice,
         * or {@code null} when the native fast path does not apply.
         */
        private static QuantizedByteVectorValues sevenBitOffHeap(KnnVectorValues values) {
            // TODO: optimize int4, 2, and single bit quantization
            if (values instanceof QuantizedByteVectorValues quantized
                && quantized.getSlice() != null
                && quantized.getScalarEncoding() == Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT) {
                return quantized;
            }
            return null;
        }

        @Override
        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, KnnVectorValues values)
            throws IOException {
            var quantized = sevenBitOffHeap(values);
            if (quantized != null && factory != null) {
                var supplier = factory.getInt7uOSQVectorScorerSupplier(VectorSimilarityType.of(sim), quantized.getSlice(), quantized);
                if (supplier.isPresent()) {
                    return supplier.get();
                }
            }
            return super.getRandomVectorScorerSupplier(sim, values);
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, float[] query)
            throws IOException {
            var quantized = sevenBitOffHeap(values);
            if (quantized != null && factory != null) {
                // Quantize the query with the same OSQ parameters used for the stored vectors.
                var quantizer = new OptimizedScalarQuantizer(sim);
                float[] residualScratch = new float[query.length];
                int[] quantizedInts = new int[query.length];
                var corrections = quantizer.scalarQuantize(
                    query,
                    residualScratch,
                    quantizedInts,
                    quantized.getScalarEncoding().getQueryBits(),
                    quantized.getCentroid()
                );
                byte[] quantizedQuery = new byte[quantizedInts.length];
                for (int i = 0; i < quantizedInts.length; i++) {
                    quantizedQuery[i] = (byte) quantizedInts[i];
                }
                var scorer = factory.getInt7uOSQVectorScorer(
                    sim,
                    quantized,
                    quantizedQuery,
                    corrections.lowerInterval(),
                    corrections.upperInterval(),
                    corrections.additionalCorrection(),
                    corrections.quantizedComponentSum()
                );
                if (scorer.isPresent()) {
                    return scorer.get();
                }
            }
            return super.getRandomVectorScorer(sim, values, query);
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, byte[] query)
            throws IOException {
            // Byte queries have no native fast path; always defer to Lucene.
            return super.getRandomVectorScorer(sim, values, query);
        }

        @Override
        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction similarityFunction,
            QuantizedByteVectorValues scoringVectors,
            QuantizedByteVectorValues targetVectors
        ) {
            // TODO improve merge-times for HNSW through off-heap optimized search
            return super.getRandomVectorScorerSupplier(similarityFunction, scoringVectors, targetVectors);
        }
    }
}
+org.elasticsearch.index.codec.vectors.es94.ES94ScalarQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es94.ES94HnswScalarQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..9a4d019db0b91 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es94/ES94HnswScalarQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es94;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.store.Directory;
import org.elasticsearch.index.codec.vectors.BFloat16;
import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.junit.AssumptionViolatedException;
import org.junit.Before;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.elasticsearch.test.ESTestCase.randomFrom;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasEntry;

/** Exercises {@link ES94HnswScalarQuantizedVectorsFormat} over bfloat16 raw vectors. */
public class ES94HnswScalarQuantizedBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsFormatTestCase {

    private static final DenseVectorFieldMapper.ElementType ELEMENT_TYPE = DenseVectorFieldMapper.ElementType.BFLOAT16;

    // Quantization width under test; re-randomized for every test in setUp.
    private int bits;

    @Before
    @Override
    public void setUp() throws Exception {
        bits = randomFrom(1, 2, 4, 7);
        super.setUp();
    }

    @Override
    protected KnnVectorsFormat createFormat() {
        return new ES94HnswScalarQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, ELEMENT_TYPE, bits, false);
    }

    @Override
    protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) {
        return new ES94HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, ELEMENT_TYPE, bits, false);
    }

    @Override
    protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) {
        return new ES94HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, ELEMENT_TYPE, bits, false, numMergeWorkers, service);
    }

    @Override
    public void testSingleVectorCase() throws Exception {
        // Base-class expectations do not hold once scores are quantized.
        throw new AssumptionViolatedException("Scalar quantization changes the score significantly for MAXIMUM_INNER_PRODUCT");
    }

    // Expects three off-heap segments: raw vectors (vec), graph (vex), quantized copy (veq).
    public void testSimpleOffHeapSize() throws IOException {
        float[] vector = randomVector(random().nextInt(12, 500));
        try (Directory dir = newDirectory()) {
            testSimpleOffHeapSize(
                dir,
                newIndexWriterConfig(),
                vector,
                allOf(
                    aMapWithSize(3),
                    hasEntry("vec", (long) vector.length * BFloat16.BYTES),
                    hasEntry("vex", 1L),
                    hasEntry(equalTo("veq"), greaterThan(0L))
                )
            );
        }
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es94;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.store.Directory;
import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase;
import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.junit.AssumptionViolatedException;
import org.junit.Before;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.elasticsearch.test.ESTestCase.randomFrom;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.is;

/** Exercises {@link ES94HnswScalarQuantizedVectorsFormat} over float raw vectors. */
public class ES94HnswScalarQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase {

    private static final DenseVectorFieldMapper.ElementType ELEMENT_TYPE = DenseVectorFieldMapper.ElementType.FLOAT;

    // Quantization width under test; re-randomized for every test in setUp.
    private int bits;

    @Before
    @Override
    public void setUp() throws Exception {
        bits = randomFrom(1, 2, 4, 7);
        super.setUp();
    }

    @Override
    protected KnnVectorsFormat createFormat() {
        return new ES94HnswScalarQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, ELEMENT_TYPE, bits, false);
    }

    @Override
    protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) {
        return new ES94HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, ELEMENT_TYPE, bits, false);
    }

    @Override
    protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) {
        return new ES94HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, ELEMENT_TYPE, bits, false, numMergeWorkers, service);
    }

    @Override
    public void testSingleVectorCase() throws Exception {
        // Base-class expectations do not hold once scores are quantized.
        throw new AssumptionViolatedException("Scalar quantization changes the score significantly for MAXIMUM_INNER_PRODUCT");
    }

    // Expects three off-heap segments: raw vectors (vec), graph (vex), quantized copy (veq).
    public void testSimpleOffHeapSize() throws IOException {
        float[] vector = randomVector(random().nextInt(12, 500));
        try (Directory dir = newDirectory()) {
            testSimpleOffHeapSize(
                dir,
                newIndexWriterConfig(),
                vector,
                allOf(
                    aMapWithSize(3),
                    hasEntry("vec", (long) vector.length * Float.BYTES),
                    hasEntry("vex", 1L),
                    hasEntry(equalTo("veq"), greaterThan(0L))
                )
            );
        }
    }

    // Pins the exact toString of the format, including the nested flat format description.
    public void testToString() {
        KnnVectorsFormat format = new ES94HnswScalarQuantizedVectorsFormat(10, 20, ELEMENT_TYPE, 2, false);
        String expected = "ES94HnswScalarQuantizedVectorsFormat(name=ES94HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, "
            + "flatVectorFormat=ES94ScalarQuantizedVectorsFormat("
            + "name=ES94ScalarQuantizedVectorsFormat, encoding=DIBIT_QUERY_NIBBLE, "
            + "flatVectorScorer="
            + ES94ScalarQuantizedVectorsFormat.flatVectorScorer
            + ", rawVectorFormat="
            + new ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false)
            + "))";
        assertThat(format, hasToString(is(expected)));
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es94;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.BFloat16;
import org.elasticsearch.index.codec.vectors.BaseBFloat16KnnVectorsFormatTestCase;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.junit.AssumptionViolatedException;
import org.junit.Before;

import java.io.IOException;

import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.elasticsearch.test.ESTestCase.randomFrom;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasEntry;

/** Exercises the flat (non-graph) {@link ES94ScalarQuantizedVectorsFormat} over bfloat16 vectors. */
public class ES94ScalarQuantizedBFloat16VectorsFormatTests extends BaseBFloat16KnnVectorsFormatTestCase {

    static {
        // native access requires logging to be initialized
        LogConfigurator.loadLog4jPlugins();
        LogConfigurator.configureESLogging();
    }

    // Format under test; rebuilt with a random quantization width for every test.
    private KnnVectorsFormat format;

    @Before
    @Override
    public void setUp() throws Exception {
        format = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, randomFrom(1, 2, 4, 7), false);
        super.setUp();
    }

    @Override
    protected Codec getCodec() {
        return TestUtil.alwaysKnnVectorsFormat(format);
    }

    public void testSearchWithVisitedLimit() {
        // The flat format has no graph, so visited-node limits do not apply.
        throw new AssumptionViolatedException("requires graph vector codec");
    }

    // Expects two off-heap segments: raw vectors (vec) and the quantized copy (veq).
    public void testSimpleOffHeapSize() throws IOException {
        float[] vector = randomVector(random().nextInt(12, 500));
        try (Directory dir = newDirectory(); IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig())) {
            Document doc = new Document();
            doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT));
            writer.addDocument(doc);
            writer.commit();
            try (IndexReader reader = DirectoryReader.open(writer)) {
                LeafReader leaf = getOnlyLeafReader(reader);
                if (leaf instanceof CodecReader codecReader) {
                    KnnVectorsReader vectorsReader = codecReader.getVectorReader();
                    if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
                        vectorsReader = fieldsReader.getFieldReader("f");
                    }
                    var fieldInfo = leaf.getFieldInfos().fieldInfo("f");
                    var offHeap = vectorsReader.getOffHeapByteSize(fieldInfo);
                    assertThat(offHeap, aMapWithSize(2));
                    assertThat(offHeap, hasEntry("vec", (long) vector.length * BFloat16.BYTES));
                    assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L)));
                }
            }
        }
    }
}
and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es94;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.junit.AssumptionViolatedException;
import org.junit.Before;

import java.io.IOException;

import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.elasticsearch.test.ESTestCase.randomFrom;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;

/**
 * Tests {@link ES94ScalarQuantizedVectorsFormat} configured with the
 * {@link DenseVectorFieldMapper.ElementType#FLOAT} element type, exercising it through Lucene's
 * standard knn-format test harness with a randomly chosen quantization bit width per test run.
 */
public class ES94ScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {

    static {
        LogConfigurator.loadLog4jPlugins();
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    /** This format stores quantized data only; there is no raw-float fallback path to test. */
    @Override
    protected boolean supportsFloatVectorFallback() {
        return false;
    }

    // Format under test; rebuilt in setUp() so each test run exercises a random bit width.
    private KnnVectorsFormat format;

    /**
     * Picks a random quantization bit width (1, 2, 4, or 7) and builds the float format
     * under test before delegating to the base class setup.
     */
    @Before
    @Override
    public void setUp() throws Exception {
        int bits = randomFrom(1, 2, 4, 7);
        format = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, bits, false);
        super.setUp();
    }

    /** Returns a codec that always uses the format built in {@link #setUp()}. */
    @Override
    protected Codec getCodec() {
        return TestUtil.alwaysKnnVectorsFormat(format);
    }

    /** Not applicable here: visited-limit search behavior needs a graph-based vector codec. */
    public void testSearchWithVisitedLimit() {
        throw new AssumptionViolatedException("requires graph vector codec");
    }

    /**
     * Verifies the format's {@code toString()} spells out its name, encoding, scorer, and the
     * wrapped raw-vector format. Note the local {@code format} deliberately shadows the field:
     * this test pins a fixed 4-bit configuration rather than the randomized one from setUp().
     */
    public void testToString() {
        var format = new ES94ScalarQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, 4, false);
        String expected = "ES94ScalarQuantizedVectorsFormat(name=ES94ScalarQuantizedVectorsFormat, encoding=PACKED_NIBBLE, "
            + "flatVectorScorer="
            + ES94ScalarQuantizedVectorsFormat.flatVectorScorer
            + ", rawVectorFormat="
            + new ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false)
            + ")";
        assertThat(format.toString(), is(expected));
    }

    /**
     * Indexes a single random-dimension float vector and verifies the reported off-heap byte
     * sizes: the raw vector file ("vec") must account for exactly one float vector, and the
     * quantized file ("veq" — presumably the quantized vector data; confirm against the format
     * implementation) must be non-empty.
     */
    public void testSimpleOffHeapSize() throws IOException {
        float[] vector = randomVector(random().nextInt(12, 500)); // random dimension in [12, 500)
        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
            Document doc = new Document();
            doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT));
            w.addDocument(doc);
            w.commit();
            try (IndexReader reader = DirectoryReader.open(w)) {
                LeafReader r = getOnlyLeafReader(reader);
                if (r instanceof CodecReader codecReader) {
                    KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
                    // Unwrap the per-field reader to reach the field-specific format reader.
                    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
                        knnVectorsReader = fieldsReader.getFieldReader("f");
                    }
                    var fieldInfo = r.getFieldInfos().fieldInfo("f");
                    var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
                    assertThat(offHeap, aMapWithSize(2));
                    // One vector stored as float32: dimension * 4 bytes.
                    assertThat(offHeap, hasEntry("vec", (long) vector.length * Float.BYTES));
                    assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L)));
                }
            }
        }
    }
}