diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6f547169ea88..22ce41f75506 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -95,6 +95,8 @@ New Features * GITHUB#14776: Add a Rescorer that uses values from provided DoubleValuesSource to re-score first pass hits. (Vigya Sharma) +* GITHUB#14708: Add a DoubleValuesSource for full precision vector similarity scores. (Vigya Sharma) + Improvements --------------------- * GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi) diff --git a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java index 2650fb164cba..903c1e9a8b6b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java @@ -24,7 +24,6 @@ import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.comparators.DoubleComparator; import org.apache.lucene.util.NumericUtils; @@ -250,14 +249,6 @@ public LongValuesSource rewrite(IndexSearcher searcher) throws IOException { */ public static DoubleValues similarityToQueryVector( LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException { - if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding() - != VectorEncoding.BYTE) { - throw new IllegalArgumentException( - "Field " - + vectorField - + " does not have the expected vector encoding: " - + VectorEncoding.BYTE); - } return new ByteVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null); } @@ -273,14 +264,6 @@ public static DoubleValues similarityToQueryVector( */ public static DoubleValues similarityToQueryVector( LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException { - if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding() - != VectorEncoding.FLOAT32) { - throw new IllegalArgumentException( - "Field " - + vectorField - + " does not have the expected vector encoding: " - + VectorEncoding.FLOAT32); - } return new FloatVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null); } diff --git a/lucene/core/src/java/org/apache/lucene/search/FullPrecisionFloatVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/FullPrecisionFloatVectorSimilarityValuesSource.java new file mode 100644 index 000000000000..b299fd97fddd --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/FullPrecisionFloatVectorSimilarityValuesSource.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; + +/** + * A {@link DoubleValuesSource} that computes vector similarity between a query vector and raw full + * precision vectors indexed in provided {@link org.apache.lucene.document.KnnFloatVectorField} in + * documents. + */ +public class FullPrecisionFloatVectorSimilarityValuesSource extends DoubleValuesSource { + + private final float[] queryVector; + private final String fieldName; + private VectorSimilarityFunction vectorSimilarityFunction; + + /** + * Creates a {@link DoubleValuesSource} that returns vector similarity score between provided + * query vector and field for documents. + * + * @param vector the query vector + * @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField} + * @param vectorSimilarityFunction the vector similarity function to use + */ + public FullPrecisionFloatVectorSimilarityValuesSource( + float[] vector, String fieldName, VectorSimilarityFunction vectorSimilarityFunction) { + this.queryVector = vector; + this.fieldName = fieldName; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + /** + * Creates a {@link DoubleValuesSource} that returns vector similarity score between provided + * query vector and field for documents. Uses the configured vector similarity function for the + * field. + * + * @param vector the query vector + * @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField} + */ + public FullPrecisionFloatVectorSimilarityValuesSource(float[] vector, String fieldName) { + this(vector, fieldName, null); + } + + /** Sugar to fetch full precision similarity score values */ + public DoubleValues getSimilarityScores(LeafReaderContext ctx) throws IOException { + return getValues(ctx, null); + } + + @Override + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); + if (vectorValues == null) { + FloatVectorValues.checkField(ctx.reader(), fieldName); + return DoubleValues.EMPTY; + } + final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName); + if (fi.getVectorDimension() != queryVector.length) { + throw new IllegalArgumentException( + "Query vector dimension does not match field dimension: " + + queryVector.length + + " != " + + fi.getVectorDimension()); + } + + if (vectorSimilarityFunction == null) { + this.vectorSimilarityFunction = fi.getVectorSimilarityFunction(); + } + final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + return new DoubleValues() { + @Override + public double doubleValue() throws IOException { + return vectorSimilarityFunction.compare( + queryVector, vectorValues.vectorValue(iterator.index())); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc); + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { + return this; + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, Arrays.hashCode(queryVector), vectorSimilarityFunction); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + FullPrecisionFloatVectorSimilarityValuesSource other = + (FullPrecisionFloatVectorSimilarityValuesSource) obj; + return Objects.equals(fieldName, other.fieldName) + && Objects.equals(vectorSimilarityFunction, other.vectorSimilarityFunction) + && Arrays.equals(queryVector, other.queryVector); + } + + @Override + public String toString() { + return "FullPrecisionFloatVectorSimilarityValuesSource(fieldName=" + + fieldName + + " vectorSimilarityFunction=" + + vectorSimilarityFunction.name() + + " queryVector=" + + Arrays.toString(queryVector) + + ")"; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestFullPrecisionFloatVectorSimilarityValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestFullPrecisionFloatVectorSimilarityValuesSource.java new file mode 100644 index 000000000000..59df35514f4a --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestFullPrecisionFloatVectorSimilarityValuesSource.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; +import org.junit.Before; +import org.junit.Test; + +public class TestFullPrecisionFloatVectorSimilarityValuesSource extends LuceneTestCase { + + private Codec savedCodec; + + private static final String KNN_FIELD = "knnField"; + private static final int NUM_VECTORS = 1000; + private static final int VECTOR_DIMENSION = 128; + + KnnVectorsFormat format; + Float confidenceInterval; + int bits; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + bits = random().nextBoolean() ? 4 : 7; + confidenceInterval = random().nextBoolean() ? random().nextFloat(0.90f, 1.0f) : null; + if (random().nextBoolean()) { + confidenceInterval = 0f; + } + format = getKnnFormat(bits); + savedCodec = Codec.getDefault(); + Codec.setDefault(getCodec()); + } + + @Override + public void tearDown() throws Exception { + Codec.setDefault(savedCodec); // restore + super.tearDown(); + } + + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + private final KnnVectorsFormat getKnnFormat(int bits) { + return new Lucene99HnswScalarQuantizedVectorsFormat( + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + 1, + bits, + bits == 4 ? random().nextBoolean() : false, + confidenceInterval, + null); + } + + @Test + public void testFullPrecisionVectorSimilarityDVS() throws Exception { + List vectors = new ArrayList<>(); + int numVectors = atLeast(NUM_VECTORS); + int numSegments = random().nextInt(2, 10); + final VectorSimilarityFunction indexingSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + + try (Directory dir = newDirectory()) { + int id = 0; + + // index some 4 bit quantized vectors + try (IndexWriter w = + new IndexWriter( + dir, + newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(4))))) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + if (random().nextInt(100) < 30) { + // skip vector for some docs to create sparse vector field + doc.add(new IntField("has_vector", 0, Field.Store.YES)); + } else { + float[] vector = TestVectorUtil.randomVector(VECTOR_DIMENSION); + vectors.add(vector); + doc.add(new IntField("id", id++, Field.Store.YES)); + doc.add(new KnnFloatVectorField(KNN_FIELD, vector, indexingSimilarityFunction)); + doc.add(new IntField("has_vector", 1, Field.Store.YES)); + } + w.addDocument(doc); + w.flush(); + } + } + // add a segment with no vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntField("has_vector", 0, Field.Store.YES)); + w.addDocument(doc); + } + w.flush(); + } + + // index some 7 bit quantized vectors + try (IndexWriter w = + new IndexWriter( + dir, + newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(7))))) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + if (random().nextInt(100) < 30) { + // skip vector for some docs to create sparse vector field + doc.add(new IntField("has_vector", 0, Field.Store.YES)); + } else { + float[] vector = TestVectorUtil.randomVector(VECTOR_DIMENSION); + vectors.add(vector); + doc.add(new IntField("id", id++, Field.Store.YES)); + doc.add(new KnnFloatVectorField(KNN_FIELD, vector, indexingSimilarityFunction)); + doc.add(new IntField("has_vector", 1, Field.Store.YES)); + } + w.addDocument(doc); + w.flush(); + } + } + // add a segment with no vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntField("has_vector", 0, Field.Store.YES)); + w.addDocument(doc); + } + w.flush(); + } + + float[] queryVector = TestVectorUtil.randomVector(VECTOR_DIMENSION); + VectorSimilarityFunction rerankSimilarityFunction; + try (IndexReader reader = DirectoryReader.open(dir)) { + for (LeafReaderContext ctx : reader.leaves()) { + DoubleValues fpSimValues; + if (random().nextBoolean()) { + rerankSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + fpSimValues = + new FullPrecisionFloatVectorSimilarityValuesSource( + queryVector, KNN_FIELD, rerankSimilarityFunction) + .getSimilarityScores(ctx); + } else { + fpSimValues = + new FullPrecisionFloatVectorSimilarityValuesSource(queryVector, KNN_FIELD) + .getSimilarityScores(ctx); + rerankSimilarityFunction = indexingSimilarityFunction; + } + DoubleValues quantizedSimValues = + DoubleValuesSource.similarityToQueryVector(ctx, queryVector, KNN_FIELD); + // validate when segment has no vectors + if (fpSimValues == DoubleValues.EMPTY || quantizedSimValues == DoubleValues.EMPTY) { + assertEquals(fpSimValues, quantizedSimValues); + assertNull(ctx.reader().getFloatVectorValues(KNN_FIELD)); + continue; + } + StoredFields storedFields = ctx.reader().storedFields(); + VectorScorer quantizedScorer = + ctx.reader().getFloatVectorValues(KNN_FIELD).scorer(queryVector); + DocIdSetIterator disi = quantizedScorer.iterator(); + while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + int doc = disi.docID(); + fpSimValues.advanceExact(doc); + quantizedSimValues.advanceExact(doc); + int idValue = + Integer.parseInt(Objects.requireNonNull(storedFields.document(doc).get("id"))); + float[] docVector = vectors.get(idValue); + assert docVector != null : "Vector for id " + idValue + " not found"; + // validate full precision vector scores + double expectedFpScore = rerankSimilarityFunction.compare(queryVector, docVector); + double actualFpScore = fpSimValues.doubleValue(); + assertEquals(expectedFpScore, actualFpScore, 1e-5); + // validate quantized vector scores + double expectedQScore = quantizedScorer.score(); + double actualQScore = quantizedSimValues.doubleValue(); + assertEquals(expectedQScore, actualQScore, 1e-5); + } + } + } + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java index 12123e58f937..1b78dab7874f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java @@ -365,12 +365,12 @@ public void testFailuresWithSimilarityValuesSource() throws Exception { byte[] byteQueryVector = new byte[] {-10, 20, 30}; expectThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> DoubleValuesSource.similarityToQueryVector( searcher.reader.leaves().get(0), floatQueryVector, "knnByteField1")); expectThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> DoubleValuesSource.similarityToQueryVector( searcher.reader.leaves().get(0), byteQueryVector, "knnFloatField1"));