diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 3bacd40482a5..f6fcb02208c5 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -177,6 +177,7 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { switch (info.getVectorEncoding()) { case BYTE -> Byte.BYTES; case FLOAT32 -> Float.BYTES; + case FLOAT16 -> Short.BYTES; }; long vectorBytes = Math.multiplyExact((long) dimension, byteSize); long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index b21df901ddb6..7ae9497129db 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -86,6 +86,7 @@ static OffHeapFloatVectorValues load( switch (fieldEntry.vectorEncoding()) { case BYTE -> fieldEntry.dimension(); case FLOAT32 -> fieldEntry.dimension() * Float.BYTES; + case FLOAT16 -> fieldEntry.dimension() * Short.BYTES; }; if (fieldEntry.docsWithFieldOffset() == -1) { return new DenseOffHeapVectorValues( diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 20571783ab67..af571da211ca 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -177,6 +177,7 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { switch (info.getVectorEncoding()) { case BYTE -> Byte.BYTES; case FLOAT32 -> Float.BYTES; + case FLOAT16 -> Short.BYTES; }; long vectorBytes = Math.multiplyExact((long) dimension, byteSize); long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java index f2b07786967d..a1a82d248dc8 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -186,4 +187,9 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { } } } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java index e7139e93b7c5..143c5601151e 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java @@ -40,6 +40,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.TopDocs; @@ -177,4 +178,9 @@ public void testSimpleOffHeapSize() throws IOException { } } } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index da62db66c89b..e92f7a98e9c8 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -175,6 +175,7 @@ private void writeField(FieldWriter fieldData, int maxDoc) throws IOException switch (fieldData.fieldInfo.getVectorEncoding()) { case BYTE -> writeByteVectors(fieldData); case FLOAT32 -> writeFloat32Vectors(fieldData); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -240,6 +241,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM switch (fieldData.fieldInfo.getVectorEncoding()) { case BYTE -> writeSortedByteVectors(fieldData, ordMap); case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -404,6 +406,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE writeVectorData( tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; CodecUtil.writeFooter(tempVectorData); IOUtils.close(tempVectorData); @@ -460,6 +463,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); yield hnswGraphBuilder.build(vectorValues.size()); } + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; writeGraph(graph); } @@ -660,6 +664,7 @@ public float[] copyValue(float[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); } }; + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; } @@ -681,6 +686,7 @@ public float[] copyValue(float[] value) { defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), FloatVectorValues.fromFloats((List) vectors, dim)); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/TestLucene94HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/TestLucene94HnswVectorsFormat.java index 393c4a427e25..cfab3164b3cc 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/TestLucene94HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/TestLucene94HnswVectorsFormat.java @@ -18,6 +18,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; public class TestLucene94HnswVectorsFormat extends BaseKnnVectorsFormatTestCase { @@ -38,4 +39,9 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { "Lucene94RWHnswVectorsFormat(name=Lucene94RWHnswVectorsFormat, maxConn=10, beamWidth=20)"; assertEquals(expectedString, customCodec.getKnnVectorsFormatForField("bogus_field").toString()); } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index 75abbb3e60b2..a565c3fd371e 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -179,6 +179,7 @@ private void writeField(FieldWriter fieldData, int maxDoc) throws IOException switch (fieldData.fieldInfo.getVectorEncoding()) { case BYTE -> writeByteVectors(fieldData); case FLOAT32 -> writeFloat32Vectors(fieldData); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -245,6 +246,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM switch (fieldData.fieldInfo.getVectorEncoding()) { case BYTE -> writeSortedByteVectors(fieldData, ordMap); case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -431,6 +433,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE writeVectorData( tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; CodecUtil.writeFooter(tempVectorData); IOUtils.close(tempVectorData); @@ -475,8 +478,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorDataInput, byteSize, defaultFlatVectorScorer, - fieldInfo.getVectorSimilarityFunction())); + fieldInfo.getVectorSimilarityFunction(), + VectorEncoding.FLOAT32)); break; + case FLOAT16: + throw new UnsupportedOperationException("FLOAT16 is not supported"); default: throw new IllegalArgumentException( "Unsupported vector encoding: " + fieldInfo.getVectorEncoding()); @@ -498,6 +504,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE case FLOAT32 -> mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); } graph = merger.merge( @@ -709,6 +716,7 @@ public float[] copyValue(float[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); } }; + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; } @@ -729,6 +737,7 @@ public float[] copyValue(float[] value) { defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), FloatVectorValues.fromFloats((List) vectors, dim)); + case FLOAT16 -> throw new UnsupportedOperationException("FLOAT16 is not supported"); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/TestLucene95HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/TestLucene95HnswVectorsFormat.java index a080e3bff7f7..c510621447ec 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/TestLucene95HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/TestLucene95HnswVectorsFormat.java @@ -18,6 +18,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase { @@ -38,4 +39,9 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { "Lucene95RWHnswVectorsFormat(name=Lucene95RWHnswVectorsFormat, maxConn=10, beamWidth=20)"; assertEquals(expectedString, customCodec.getKnnVectorsFormatForField("bogus_field").toString()); } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index fb2ca112a0ab..ba16503aa4f4 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -40,6 +40,7 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.IndexSearcher; @@ -369,4 +370,9 @@ public void testVectorSimilarityFuncs() { var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java index e2019719792f..971097c2df6e 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -64,4 +65,9 @@ public void testSimpleOffHeapSize() throws IOException { } } } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index ee9765f2ac0e..7a6082f02d52 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -42,6 +42,7 @@ import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -407,4 +408,9 @@ public void testRandomWithUpdatesAndGraph() { public void testSearchWithVisitedLimit() { // search not supported } + + @Override + protected VectorEncoding randomVectorEncoding() { + return random().nextBoolean() ? VectorEncoding.BYTE : VectorEncoding.FLOAT32; + } } diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java index ab9a5d976c27..8342e5e3e3b5 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerFloat32Benchmark.java @@ -38,6 +38,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -326,7 +327,8 @@ static KnnVectorValues vectorValues( in.slice("test", 0, in.length()), byteSize, new ThrowingFlatVectorScorer(), - sim); + sim, + VectorEncoding.FLOAT32); } static final class ThrowingFlatVectorScorer implements FlatVectorsScorer { diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 9e2acc3caac8..2e49b8d51a19 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -19,7 +19,8 @@ @SuppressWarnings("module") // the test framework is compiled after the core... module org.apache.lucene.core { requires java.logging; - requires static jdk.management; // this is optional but explicit declaration is recommended + requires static jdk.management; + requires java.desktop; // this is optional but explicit declaration is recommended exports org.apache.lucene.analysis.standard; exports org.apache.lucene.analysis.tokenattributes; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 96b0f75a259f..e19f67dca128 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -67,6 +67,8 @@ public byte[] copyValue(byte[] vectorValue) { } }; break; + case FLOAT16: + throw new UnsupportedOperationException("FLOAT16 is not supported"); default: throw new UnsupportedOperationException(); } @@ -105,6 +107,8 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { : bufferedByteVectorValues; writeField(fieldData.fieldInfo, byteVectorValues, maxDoc); break; + case FLOAT16: + throw new UnsupportedOperationException("FLOAT16 is not supported"); } } } @@ -207,6 +211,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); writeField(fieldInfo, byteVectorValues, mergeState.segmentInfo.maxDoc()); break; + case FLOAT16: + throw new UnsupportedOperationException("FLOAT16 is not supported"); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 50af32a7e162..1f29689cc2eb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -66,7 +66,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE byteWriter.addValue(doc, mergedBytes.vectorValue(iter.index())); } } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { KnnFieldVectorsWriter floatWriter = (KnnFieldVectorsWriter) addField(fieldInfo); FloatVectorValues mergedFloats = @@ -215,13 +215,19 @@ public static void mapOldOrdToNewOrd( public static final class MergedVectorValues { private MergedVectorValues() {} - private static void validateFieldEncoding(FieldInfo fieldInfo, VectorEncoding expected) { + private static void validateFieldEncoding(FieldInfo fieldInfo, VectorEncoding... expected) { assert fieldInfo != null && fieldInfo.hasVectorValues(); VectorEncoding fieldEncoding = fieldInfo.getVectorEncoding(); - if (fieldEncoding != expected) { - throw new UnsupportedOperationException( - "Cannot merge vectors encoded as [" + fieldEncoding + "] as " + expected); + for (VectorEncoding exp : expected) { + if (fieldEncoding == exp) { + return; + } } + throw new UnsupportedOperationException( + "Cannot merge vectors encoded as [" + + fieldEncoding + + "] as " + + Arrays.toString(expected)); } /** @@ -267,8 +273,8 @@ private static List mergeVectorValues( /** Returns a merged view over all the segment's {@link FloatVectorValues}. */ public static FloatVectorValues mergeFloatVectorValues( FieldInfo fieldInfo, MergeState mergeState) throws IOException { - validateFieldEncoding(fieldInfo, VectorEncoding.FLOAT32); - return new MergedFloat32VectorValues( + validateFieldEncoding(fieldInfo, VectorEncoding.FLOAT32, VectorEncoding.FLOAT16); + return new MergedFloatVectorValues( mergeVectorValues( mergeState.knnVectorsReaders, mergeState.docMaps, @@ -294,7 +300,7 @@ public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeS mergeState); } - static class MergedFloat32VectorValues extends FloatVectorValues { + static class MergedFloatVectorValues extends FloatVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; @@ -302,7 +308,7 @@ static class MergedFloat32VectorValues extends FloatVectorValues { private int lastOrd = -1; FloatVectorValuesSub current; - private MergedFloat32VectorValues(List subs, MergeState mergeState) + private MergedFloatVectorValues(List subs, MergeState mergeState) throws IOException { this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index 6a94eef33771..30c43f163498 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -40,7 +40,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { switch (vectorValues.getEncoding()) { - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { return new FloatScoringSupplier((FloatVectorValues) vectorValues, similarityFunction); } case BYTE -> { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index 4e6288a0aa97..d259fd68f086 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -201,14 +201,17 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { if (fi == null) { return null; } - if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + if (fi.vectorEncoding != VectorEncoding.FLOAT32 + && fi.vectorEncoding != VectorEncoding.FLOAT16) { throw new IllegalArgumentException( "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + VectorEncoding.FLOAT32 + + " or " + + VectorEncoding.FLOAT16); } FloatVectorValues rawFloatVectorValues = rawVectorsReader.getFloatVectorValues(field); @@ -391,14 +394,17 @@ public org.apache.lucene.util.quantization.QuantizedByteVectorValues getQuantize if (fi == null) { return null; } - if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + if (fi.vectorEncoding != VectorEncoding.FLOAT32 + && fi.vectorEncoding != VectorEncoding.FLOAT16) { throw new IllegalArgumentException( "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + VectorEncoding.FLOAT32 + + " or " + + VectorEncoding.FLOAT16); } var qv = OffHeapScalarQuantizedVectorValues.load( diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index a579f588f4f7..0ecbf6cb6aee 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -124,7 +124,8 @@ public Lucene104ScalarQuantizedVectorsWriter( @Override public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); - if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding(); + if (vectorEncoding == VectorEncoding.FLOAT32 || vectorEncoding == VectorEncoding.FLOAT16) { @SuppressWarnings("unchecked") FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); @@ -325,7 +326,8 @@ public void finish() throws IOException { @Override public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding(); + if (vectorEncoding != VectorEncoding.FLOAT32 && vectorEncoding != VectorEncoding.FLOAT16) { rawVectorDelegate.mergeOneField(fieldInfo, mergeState); return; } @@ -435,7 +437,8 @@ static DocsWithFieldSet writeBinarizedVectorAndQueryData( @Override public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( FieldInfo fieldInfo, MergeState mergeState) throws IOException { - if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding(); + if (vectorEncoding != VectorEncoding.FLOAT32 && vectorEncoding != VectorEncoding.FLOAT16) { return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); } @@ -650,7 +653,9 @@ static int mergeAndRecalculateCentroids( static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException { - assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + + VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding(); + assert vectorEncoding != VectorEncoding.FLOAT32 || vectorEncoding != VectorEncoding.FLOAT16; // clear out the centroid Arrays.fill(centroid, 0); int count = 0; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 7d221be3a907..0f8e7be1d15e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -42,6 +42,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme protected final float[] value; protected final VectorSimilarityFunction similarityFunction; protected final FlatVectorsScorer flatVectorsScorer; + protected final VectorEncoding vectorEncoding; OffHeapFloatVectorValues( int dimension, @@ -49,7 +50,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme IndexInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) { + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding) { this.dimension = dimension; this.size = size; this.slice = slice; @@ -57,6 +59,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme this.similarityFunction = similarityFunction; this.flatVectorsScorer = flatVectorsScorer; value = new float[dimension]; + this.vectorEncoding = vectorEncoding; } @Override @@ -80,11 +83,25 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } slice.seek((long) targetOrd * byteSize); - slice.readFloats(value, 0, value.length); + + if (vectorEncoding == VectorEncoding.FLOAT16) { + short[] shortValues = new short[dimension]; + slice.readShorts(shortValues, 0, dimension); + for (int i = 0; i < dimension; i++) { + value[i] = Float.float16ToFloat(shortValues[i]); + } + } else { + slice.readFloats(value, 0, value.length); + } lastOrd = targetOrd; return value; } + @Override + public VectorEncoding getEncoding() { + return vectorEncoding; + } + public static OffHeapFloatVectorValues load( VectorSimilarityFunction vectorSimilarityFunction, FlatVectorsScorer flatVectorsScorer, @@ -95,11 +112,17 @@ public static OffHeapFloatVectorValues load( long vectorDataLength, IndexInput vectorData) throws IOException { - if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) { - return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); + if (configuration.docsWithFieldOffset == -2 + || (vectorEncoding != VectorEncoding.FLOAT32 && vectorEncoding != VectorEncoding.FLOAT16)) { + return new EmptyOffHeapVectorValues( + dimension, flatVectorsScorer, vectorSimilarityFunction, vectorEncoding); } IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); int byteSize = dimension * Float.BYTES; + if (vectorEncoding == VectorEncoding.FLOAT16) { + byteSize = dimension * Short.BYTES; + } + if (configuration.docsWithFieldOffset == -1) { return new DenseOffHeapVectorValues( dimension, @@ -107,7 +130,8 @@ public static OffHeapFloatVectorValues load( bytesSlice, byteSize, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + vectorEncoding); } else { return new SparseOffHeapVectorValues( configuration, @@ -116,7 +140,8 @@ public static OffHeapFloatVectorValues load( dimension, byteSize, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + vectorEncoding); } } @@ -132,14 +157,22 @@ public DenseOffHeapVectorValues( IndexInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) { - super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding) { + super( + dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction, vectorEncoding); } @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( - dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + dimension, + size, + slice.clone(), + byteSize, + flatVectorsScorer, + similarityFunction, + vectorEncoding); } @Override @@ -196,10 +229,18 @@ public SparseOffHeapVectorValues( int dimension, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding) throws IOException { - super(dimension, configuration.size, slice, byteSize, flatVectorsScorer, similarityFunction); + super( + dimension, + configuration.size, + slice, + byteSize, + flatVectorsScorer, + similarityFunction, + vectorEncoding); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -224,7 +265,8 @@ public SparseOffHeapVectorValues copy() throws IOException { dimension, byteSize, flatVectorsScorer, - similarityFunction); + similarityFunction, + vectorEncoding); } @Override @@ -285,8 +327,9 @@ private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { public EmptyOffHeapVectorValues( int dimension, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) { - super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction, vectorEncoding); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index 607f1154875f..f15428a2e8da 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -21,6 +21,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; import java.io.IOException; +import java.util.Arrays; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; @@ -192,23 +193,26 @@ private FieldEntry getFieldEntryOrThrow(String field) { return entry; } - private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + private FieldEntry getFieldEntry(String field, VectorEncoding... expectedEncoding) { final FieldEntry fieldEntry = getFieldEntryOrThrow(field); - if (fieldEntry.vectorEncoding != expectedEncoding) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + expectedEncoding); + for (VectorEncoding expected : expectedEncoding) { + if (fieldEntry.vectorEncoding == expected) { + return fieldEntry; + } } - return fieldEntry; + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + Arrays.toString(expectedEncoding)); } @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + final FieldEntry fieldEntry = + getFieldEntry(field, VectorEncoding.FLOAT32, VectorEncoding.FLOAT16); return OffHeapFloatVectorValues.load( fieldEntry.similarityFunction, vectorScorer, @@ -236,7 +240,8 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { - final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + final FieldEntry fieldEntry = + getFieldEntry(field, VectorEncoding.FLOAT32, VectorEncoding.FLOAT16); return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapFloatVectorValues.load( @@ -315,6 +320,7 @@ private record FieldEntry( switch (info.getVectorEncoding()) { case BYTE -> Byte.BYTES; case FLOAT32 -> Float.BYTES; + case FLOAT16 -> Short.BYTES; }; long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); long numBytes = Math.multiplyExact(vectorBytes, size); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 1432f5ea46b8..a16b84f97ffd 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -159,6 +159,7 @@ private void writeField(FieldWriter fieldData, int maxDoc) throws IOException switch (fieldData.fieldInfo.getVectorEncoding()) { case BYTE -> writeByteVectors(fieldData); case FLOAT32 -> writeFloat32Vectors(fieldData); + case FLOAT16 -> writeFloat16Vectors(fieldData); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -175,6 +176,19 @@ private void writeFloat32Vectors(FieldWriter fieldData) throws IOException { } } + private void writeFloat16Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = + ByteBuffer.allocate(fieldData.dim * Short.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + float[] vector = (float[]) v; + buffer.clear(); + for (float f : vector) { + buffer.putShort(Float.floatToFloat16(f)); + } + vectorData.writeBytes(buffer.array(), buffer.position()); + } + } + private void writeByteVectors(FieldWriter fieldData) throws IOException { for (Object v : fieldData.vectors) { byte[] vector = (byte[]) v; @@ -194,6 +208,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM switch (fieldData.fieldInfo.getVectorEncoding()) { case BYTE -> writeSortedByteVectors(fieldData, ordMap); case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap); + case FLOAT16 -> writeSortedFloat16Vectors(fieldData, ordMap); }; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -213,6 +228,22 @@ private long writeSortedFloat32Vectors(FieldWriter fieldData, int[] ordMap) return vectorDataOffset; } + private long writeSortedFloat16Vectors(FieldWriter fieldData, int[] ordMap) + throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Short.BYTES); + final ByteBuffer buffer = + ByteBuffer.allocate(fieldData.dim * Short.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + buffer.clear(); + for (float f : vector) { + buffer.putShort(Float.floatToFloat16(f)); + } + vectorData.writeBytes(buffer.array(), buffer.position()); + } + return vectorDataOffset; + } + private long writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); for (int ordinal : ordMap) { @@ -239,6 +270,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorData, KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues( fieldInfo, mergeState)); + case FLOAT16 -> + writeFloat16VectorData( + vectorData, + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues( + fieldInfo, mergeState)); }; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta( @@ -271,6 +307,11 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( tempVectorData, KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues( fieldInfo, mergeState)); + case FLOAT16 -> + writeFloat16VectorData( + tempVectorData, + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues( + fieldInfo, mergeState)); }; CodecUtil.writeFooter(tempVectorData); IOUtils.close(tempVectorData); @@ -309,16 +350,17 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( fieldInfo.getVectorDimension() * Byte.BYTES, vectorsScorer, fieldInfo.getVectorSimilarityFunction())); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> vectorsScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), new OffHeapFloatVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Float.BYTES, + fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize, vectorsScorer, - fieldInfo.getVectorSimilarityFunction())); + fieldInfo.getVectorSimilarityFunction(), + fieldInfo.getVectorEncoding())); }; return new FlatCloseableRandomVectorScorerSupplier( () -> { @@ -394,6 +436,29 @@ private static DocsWithFieldSet writeVectorData( return docsWithField; } + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + private static DocsWithFieldSet writeFloat16VectorData( + IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = + ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT16.byteSize) + .order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(iter.index()); + buffer.clear(); + for (float f : value) { + buffer.putShort(Float.floatToFloat16(f)); + } + output.writeBytes(buffer.array(), buffer.position()); + docsWithField.add(docV); + } + return docsWithField; + } + @Override public void close() throws IOException { IOUtils.close(meta, vectorData); @@ -420,7 +485,7 @@ public byte[] copyValue(byte[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); } }; - case FLOAT32 -> + case FLOAT32, FLOAT16 -> new Lucene99FlatVectorsWriter.FieldWriter(fieldInfo) { @Override public float[] copyValue(float[] value) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index d7d95d45486e..c283cbde150f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -284,24 +284,27 @@ private FieldEntry getFieldEntryOrThrow(String field) { return entry; } - private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + private FieldEntry getFieldEntry(String field, VectorEncoding... expectedEncoding) { final FieldEntry fieldEntry = getFieldEntryOrThrow(field); - if (fieldEntry.vectorEncoding != expectedEncoding) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + expectedEncoding); + for (VectorEncoding expected : expectedEncoding) { + if (fieldEntry.vectorEncoding == expected) { + return fieldEntry; + } } - return fieldEntry; + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + Arrays.toString(expectedEncoding)); } @Override public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { - final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + final FieldEntry fieldEntry = + getFieldEntry(field, VectorEncoding.FLOAT32, VectorEncoding.FLOAT16); search( fieldEntry, knnCollector, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index d2526cff3ab8..f6a02775bce8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -436,7 +436,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE case BYTE -> mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); } @@ -655,7 +655,7 @@ static FieldWriter create( beamWidth, infoStream, tinySegmentsThreshold); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> new FieldWriter<>( scorer, (FlatFieldVectorsWriter) flatFieldVectorsWriter, @@ -691,7 +691,7 @@ static FieldWriter create( ByteVectorValues.fromBytes( (List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension())); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), FloatVectorValues.fromFloats( diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index 63a55ddc669f..94a26caa34eb 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -56,6 +56,33 @@ private static FieldType createType(float[] v, VectorSimilarityFunction similari return type; } + private static FieldType createType( + float[] v, VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding) { + if (v == null) { + throw new IllegalArgumentException("vector value must not be null"); + } + int dimension = v.length; + if (dimension == 0) { + throw new IllegalArgumentException("cannot index an empty vector"); + } + if (similarityFunction == null) { + throw new IllegalArgumentException("similarity function must not be null"); + } + + if (vectorEncoding == null) { + throw new IllegalArgumentException("Vector encoding must not be null"); + } + + if (vectorEncoding != VectorEncoding.FLOAT16 && vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException("Vector encoding must be FLOAT16 or FLOAT32"); + } + + FieldType type = new FieldType(); + type.setVectorAttributes(dimension, vectorEncoding, similarityFunction); + type.freeze(); + return type; + } + /** * A convenience method for creating a vector field type. * @@ -71,6 +98,22 @@ public static FieldType createFieldType( return type; } + /** + * A convenience method for creating a vector field type. + * + * @param dimension dimension of vectors. + * @param similarityFunction a function defining vector proximity. + * @param vectorEncoding the encoding format for the vector. Currently, supports FLOAT16 and + * FLOAT32. + */ + public static FieldType createFieldType( + int dimension, VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding) { + FieldType type = new FieldType(); + type.setVectorAttributes(dimension, vectorEncoding, similarityFunction); + type.freeze(); + return type; + } + /** * Create a new vector query for the provided field targeting the float vector * @@ -101,6 +144,24 @@ public KnnFloatVectorField( fieldsData = VectorUtil.checkFinite(vector); // null check done above } + /** + * Creates a new KnnFloatVectorField with the specified name, vector, similarity function, and + * encoding. + * + * @param name the field name + * @param vector the float vector value + * @param similarityFunction the similarity function to use for vector comparisons + * @param vectorEncoding the encoding format for the vector + */ + public KnnFloatVectorField( + String name, + float[] vector, + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding) { + super(name, createType(vector, similarityFunction, vectorEncoding)); + fieldsData = VectorUtil.checkFinite(vector); // null check done above + } + /** * Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are * single-valued: each document has either one value or no value. Vectors of a single field share @@ -127,7 +188,8 @@ public KnnFloatVectorField(String name, float[] vector) { */ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) { super(name, fieldType); - if (fieldType.vectorEncoding() != VectorEncoding.FLOAT32) { + if ((fieldType.vectorEncoding() != VectorEncoding.FLOAT32 + && fieldType.vectorEncoding() != VectorEncoding.FLOAT16)) { throw new IllegalArgumentException( "Attempt to create a vector for field " + name diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index dc3242722636..fe423fb2d60b 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -2858,7 +2858,7 @@ public static Status.VectorValuesStatus testVectors( status, reader); break; - case FLOAT32: + case FLOAT32, FLOAT16: checkFloatVectorValues( Objects.requireNonNull(reader.getFloatVectorValues(fieldInfo.name)), fieldInfo, diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index a39b05ee9829..f5b820747f85 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -236,7 +236,8 @@ public final FloatVectorValues getFloatVectorValues(String field) throws IOExcep FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0 - || fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + || (fi.getVectorEncoding() != VectorEncoding.FLOAT32 + && fi.getVectorEncoding() != VectorEncoding.FLOAT16)) { // Field does not exist or does not index vectors return null; } @@ -266,7 +267,8 @@ public final void searchNearestVectors( FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0 - || fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + || (fi.getVectorEncoding() != VectorEncoding.FLOAT32 + && fi.getVectorEncoding() != VectorEncoding.FLOAT16)) { // Field does not exist or does not index vectors return; } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index b318f7c162bc..950b325fffa8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -52,7 +52,10 @@ protected FloatVectorValues() {} */ public static void checkField(LeafReader in, String field) { FieldInfo fi = in.getFieldInfos().fieldInfo(field); - if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + if (fi != null + && fi.hasVectorValues() + && (fi.getVectorEncoding() != VectorEncoding.FLOAT32 + && fi.getVectorEncoding() != VectorEncoding.FLOAT16)) { throw new IllegalStateException( "Unexpected vector encoding (" + fi.getVectorEncoding() @@ -60,6 +63,8 @@ public static void checkField(LeafReader in, String field) { + field + "(expected=" + VectorEncoding.FLOAT32 + + " " + + VectorEncoding.FLOAT16 + ")"); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java index 80d18cea435a..47a3b2c87a23 100644 --- a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java +++ b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java @@ -1030,7 +1030,7 @@ private void indexVectorValue( case BYTE -> ((KnnFieldVectorsWriter) pf.knnFieldVectorsWriter) .addValue(docID, ((KnnByteVectorField) field).vectorValue()); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> ((KnnFieldVectorsWriter) pf.knnFieldVectorsWriter) .addValue(docID, ((KnnFloatVectorField) field).vectorValue()); } diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java b/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java index 8ae6dd40e343..4425a35586a2 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorEncoding.java @@ -29,7 +29,10 @@ public enum VectorEncoding { BYTE(1), /** Encodes vector using 32 bits of precision per sample in IEEE floating point format. */ - FLOAT32(4); + FLOAT32(4), + + /** Encodes vector using 16 bits of precision per sample in IEEE floating point format. */ + FLOAT16(2); /** * The number of bytes required to encode a scalar in this format. A vector will nominally require diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java index f6bac7b52bc6..c475795c1ed7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java @@ -180,7 +180,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors iterator = switch (fieldInfo.getVectorEncoding()) { - case FLOAT32 -> context.reader().getFloatVectorValues(field).iterator(); + case FLOAT32, FLOAT16 -> context.reader().getFloatVectorValues(field).iterator(); case BYTE -> context.reader().getByteVectorValues(field).iterator(); }; } else if (fieldInfo.getDocValuesType() @@ -282,7 +282,7 @@ private String buildErrorMsg(FieldInfo fieldInfo) { private int getVectorValuesSize(FieldInfo fi, LeafReader reader) throws IOException { assert fi.name.equals(field); return switch (fi.getVectorEncoding()) { - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field); assert floatVectorValues != null : "unexpected null float vector values"; yield floatVectorValues.size(); diff --git a/lucene/core/src/java/org/apache/lucene/store/DataInput.java b/lucene/core/src/java/org/apache/lucene/store/DataInput.java index 62ea389a8357..7bcf13a234ee 100644 --- a/lucene/core/src/java/org/apache/lucene/store/DataInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/DataInput.java @@ -79,6 +79,20 @@ public short readShort() throws IOException { return (short) (((b2 & 0xFF) << 8) | (b1 & 0xFF)); } + /** + * Reads a specified number of shorts into an array at the specified offset. + * + * @param values the array to read shorts into + * @param offset the offset in the array to start storing shorts + * @param len the number of shorts to read + */ + public void readShorts(short[] values, int offset, int len) throws IOException { + Objects.checkFromIndexSize(offset, len, values.length); + for (int i = 0; i < len; i++) { + values[offset + i] = readShort(); + } + } + /** * Reads four bytes and returns an int (LE byte order). * diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java index 4b6244c18522..5fcb736cb2c5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java @@ -107,7 +107,7 @@ private static int[] getNewOrdMapping( switch (fieldInfo.getVectorEncoding()) { case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator(); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name).iterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java index dfedb66feda1..70a3985a4ad7 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java @@ -96,7 +96,7 @@ public IncrementalHnswGraphMerger addReader( KnnVectorValues knnVectorValues = switch (fieldInfo.getVectorEncoding()) { case BYTE -> reader.getByteVectorValues(fieldInfo.name); - case FLOAT32 -> reader.getFloatVectorValues(fieldInfo.name); + case FLOAT32, FLOAT16 -> reader.getFloatVectorValues(fieldInfo.name); }; int candidateVectorCount = countLiveVectors(liveDocs, knnVectorValues); @@ -171,7 +171,7 @@ protected final int[][] getNewOrdMapping( switch (fieldInfo.getVectorEncoding()) { case BYTE -> vectorsIter = graphReaders.get(i).reader.getByteVectorValues(fieldInfo.name).iterator(); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> vectorsIter = graphReaders.get(i).reader.getFloatVectorValues(fieldInfo.name).iterator(); } diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index ade4a248c0db..609b1f479861 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -43,7 +43,8 @@ private Lucene99MemorySegmentFlatVectorsScorer(FlatVectorsScorer delegate) { public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityType, KnnVectorValues vectorValues) throws IOException { return switch (vectorValues.getEncoding()) { - case FLOAT32 -> getFloatScoringSupplier((FloatVectorValues) vectorValues, similarityType); + case FLOAT32, FLOAT16 -> + getFloatScoringSupplier((FloatVectorValues) vectorValues, similarityType); case BYTE -> getByteScorerSupplier((ByteVectorValues) vectorValues, similarityType); }; } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index 9bea4fcd3877..86eaff57fd2e 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -359,7 +359,8 @@ FloatVectorValues floatVectorValues( in.slice("floatValues", 0, in.length()), dims * Float.BYTES, flatVectorsScorer, - sim); + sim, + VectorEncoding.FLOAT32); } /** Concatenates float arrays as byte[]. */ diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java index 86d10cb5af55..663c23d0144e 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java @@ -44,6 +44,7 @@ import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -412,7 +413,13 @@ KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarity KnnVectorValues floatVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - dims, size, in.slice("floatValues", 0, in.length()), dims, MEMSEG_SCORER, sim); + dims, + size, + in.slice("floatValues", 0, in.length()), + dims, + MEMSEG_SCORER, + sim, + VectorEncoding.FLOAT32); } // creates the vector based on the given ordinal, which is reproducible given the ord and dims diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index ab469721edaa..64394f641a1c 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -831,7 +831,7 @@ public void testMergeAwayAllValues() throws IOException { case BYTE: vectorValues = leafReader.getByteVectorValues("field"); break; - case FLOAT32: + case FLOAT32, FLOAT16: vectorValues = leafReader.getFloatVectorValues("field"); break; default: diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java index 5432879a1341..e734021dbbc9 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java @@ -48,7 +48,7 @@ public void testFindAll() throws IOException { case BYTE: vectorScorer = context.reader().getByteVectorValues("field").scorer(new byte[] {1, 2}); break; - case FLOAT32: + case FLOAT32, FLOAT16: vectorScorer = context.reader().getFloatVectorValues("field").scorer(new float[] {1, 2}); break; default: diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index db3c590d920b..815f218251b0 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -123,7 +123,7 @@ protected RandomVectorScorer buildScorer(KnnVectorValues vectors, T query) throw return switch (getVectorEncoding()) { case BYTE -> flatVectorScorer.getRandomVectorScorer(similarityFunction, vectorsCopy, (byte[]) query); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> flatVectorScorer.getRandomVectorScorer(similarityFunction, vectorsCopy, (float[]) query); }; } @@ -175,7 +175,7 @@ public void testRandomReadWriteAndMerge() throws IOException { (T) ((ByteVectorValues) vectors).vectorValue(ord), similarityFunction)); } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { doc.add( knnVectorField( "field", @@ -240,7 +240,7 @@ public void testGraphMergeWithDeletes() throws IOException { (T) ((ByteVectorValues) vectors).vectorValue(i), similarityFunction)); } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { doc.add( knnVectorField( vectorFieldName, @@ -283,7 +283,7 @@ private T vectorValue(KnnVectorValues vectors, int ord) throws IOException { case BYTE -> { return (T) ((ByteVectorValues) vectors).vectorValue(ord); } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { return (T) ((FloatVectorValues) vectors).vectorValue(ord); } } @@ -809,7 +809,7 @@ public void testFindAll() throws IOException { case BYTE -> similarityFunction.compare( ((ByteVectorValues) vectorValues).vectorValue(i), (byte[]) target); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> similarityFunction.compare( ((FloatVectorValues) vectorValues).vectorValue(i), (float[]) target); }; @@ -1254,6 +1254,71 @@ public VectorScorer scorer(float[] target) { } } + /** Returns vectors evenly distributed around the upper unit semicircle. */ + static class CircularFloat16VectorValues extends FloatVectorValues { + private final int size; + private final float[] value; + + int doc = -1; + + CircularFloat16VectorValues(int size) { + this.size = size; + value = new float[2]; + } + + @Override + public CircularFloat16VectorValues copy() { + return new CircularFloat16VectorValues(size); + } + + @Override + public int dimension() { + return 2; + } + + @Override + public int size() { + return size; + } + + public float[] vectorValue() { + return vectorValue(doc); + } + + public int docID() { + return doc; + } + + public int nextDoc() { + return advance(doc + 1); + } + + public int advance(int target) { + if (target >= 0 && target < size) { + doc = target; + } else { + doc = NO_MORE_DOCS; + } + return doc; + } + + @Override + public float[] vectorValue(int ord) { + return unitVector2d(ord / (double) size, value); + } + + private static float[] unitVector2d(double piRadians, float[] value) { + value[0] = Float.float16ToFloat(Float.floatToFloat16((float) Math.cos(Math.PI * piRadians))); + value[1] = Float.float16ToFloat(Float.floatToFloat16((float) Math.sin(Math.PI * piRadians))); + return value; + } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } + } + /** Returns vectors evenly distributed around the upper unit semicircle. */ static class CircularByteVectorValues extends ByteVectorValues { private final int size; @@ -1351,7 +1416,7 @@ void assertVectorsEqual(KnnVectorValues u, KnnVectorValues v) throws IOException "vectors do not match for doc=" + uDoc, (byte[]) vectorValue(u, ord), (byte[]) vectorValue(v, ord)); - case FLOAT32 -> + case FLOAT32, FLOAT16 -> assertArrayEquals( "vectors do not match for doc=" + uDoc, (float[]) vectorValue(u, ord), @@ -1371,6 +1436,14 @@ static float[][] createRandomFloatVectors(int size, int dimension, Random random return vectors; } + static float[][] createRandomFloat16Vectors(int size, int dimension, Random random) { + float[][] vectors = new float[size][]; + for (int offset = 0; offset < size; offset++) { + vectors[offset] = randomFloat16Vector(random, dimension); + } + return vectors; + } + static byte[][] createRandomByteVectors(int size, int dimension, Random random) { byte[][] vectors = new byte[size][]; for (int offset = 0; offset < size; offset++) { @@ -1410,6 +1483,14 @@ static float[] randomVector(Random random, int dim) { return vec; } + static float[] randomFloat16Vector(Random random, int dim) { + float[] vec = randomVector(random, dim); + for (int i = 0; i < dim; i++) { + vec[i] = Float.float16ToFloat(Float.floatToFloat16(vec[i])); + } + return vec; + } + static byte[] randomVector8(Random random, int dim) { float[] fvec = randomVector(random, dim); byte[] bvec = new byte[dim]; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloat16VectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloat16VectorGraph.java new file mode 100644 index 000000000000..043559f1ae39 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloat16VectorGraph.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import java.io.IOException; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.FixedBitSet; +import org.junit.Before; + +/** Tests HNSW KNN graphs */ +public class TestHnswFloat16VectorGraph extends HnswGraphTestCase { + + @Before + public void setup() { + similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); + } + + @Override + VectorEncoding getVectorEncoding() { + return VectorEncoding.FLOAT16; + } + + @Override + Query knnQuery(String field, float[] vector, int k) { + return new KnnFloatVectorQuery(field, vector, k); + } + + @Override + float[] randomVector(int dim) { + return randomFloat16Vector(random(), dim); + } + + @Override + MockVectorValues vectorValues(int size, int dimension) { + return MockVectorValues.fromValues(createRandomFloat16Vectors(size, dimension, random())); + } + + @Override + MockVectorValues vectorValues(float[][] values) { + return MockVectorValues.fromValues(values); + } + + @Override + MockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException { + FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName); + float[][] vectors = new float[reader.maxDoc()][]; + for (int i = 0; i < vectorValues.size(); i++) { + vectors[vectorValues.ordToDoc(i)] = + ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension()); + } + return MockVectorValues.fromValues(vectors); + } + + @Override + MockVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) { + MockVectorValues pvv = (MockVectorValues) pregeneratedVectorValues; + float[][] vectors = new float[size][]; + float[][] randomVectors = + createRandomFloat16Vectors(size - pvv.values.length, dimension, random()); + + for (int i = 0; i < pregeneratedOffset; i++) { + vectors[i] = randomVectors[i]; + } + + for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) { + vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd]; + } + + for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) { + vectors[i] = randomVectors[i - pvv.values.length]; + } + + return MockVectorValues.fromValues(vectors); + } + + @Override + Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction, VectorEncoding.FLOAT16); + } + + @Override + CircularFloat16VectorValues circularVectorValues(int nDoc) { + return new CircularFloat16VectorValues(nDoc); + } + + @Override + float[] getTargetVector() { + return new float[] {1f, 0f}; + } + + public void testSearchWithSkewedAcceptOrds() throws IOException { + int nDoc = 1000; + similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + FloatVectorValues vectors = circularVectorValues(nDoc); + RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); + HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); + OnHeapHnswGraph hnsw = builder.build(vectors.size()); + + // Skip over half of the documents that are closest to the query vector + FixedBitSet acceptOrds = new FixedBitSet(nDoc); + for (int i = 500; i < nDoc; i++) { + acceptOrds.set(i); + } + KnnCollector nn = + HnswGraphSearcher.search( + buildScorer(vectors, getTargetVector()), 10, hnsw, acceptOrds, Integer.MAX_VALUE); + + TopDocs nodes = nn.topDocs(); + assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length); + int sum = 0; + for (ScoreDoc node : nodes.scoreDocs) { + assertTrue("the results include a deleted document: " + node, acceptOrds.get(node.doc)); + sum += node.doc; + } + // We still expect to get reasonable recall. The lowest non-skipped docIds + // are closest to the query vector: sum(500,509) = 5045 + assertTrue("sum(result docs)=" + sum, sum < 5100); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 52d1da3dfa83..e4728f904a55 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -104,7 +104,7 @@ MockVectorValues vectorValues( @Override Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { - return new KnnFloatVectorField(name, vector, similarityFunction); + return new KnnFloatVectorField(name, vector, similarityFunction, VectorEncoding.FLOAT32); } @Override diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 145a6c8ce6f0..2712fcdb106e 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -813,7 +813,7 @@ private void storeVectorValues(Info info, IndexableField vectorField) { + vectorField.name() + "] is not a byte vector field, but the field info is configured for byte vectors"); } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { if (vectorField instanceof KnnFloatVectorField floatVectorField) { if (info.floatVectorCount == 1) { throw new IllegalArgumentException( diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java index f41986d8e8ce..d1b6d534d8ca 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java @@ -103,6 +103,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE // TODO: Support using SQ8 quantization, see: // - https://github.com/opensearch-project/k-NN/pull/2425 throw new UnsupportedOperationException("Byte vectors not supported"); + case FLOAT16 -> throw new UnsupportedOperationException("Float16 vectors not supported"); case FLOAT32 -> { FloatVectorValues merged = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); @@ -129,6 +130,8 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { // - https://github.com/opensearch-project/k-NN/pull/2425 throw new UnsupportedOperationException("Byte vectors not supported"); + case FLOAT16 -> throw new UnsupportedOperationException("Float16 vectors not supported"); + case FLOAT32 -> { @SuppressWarnings("unchecked") FlatFieldVectorsWriter rawWriter = diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index 97cb31911ba9..5f6f6b747f67 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -134,7 +134,8 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 - && fi.getVectorEncoding() == VectorEncoding.FLOAT32; + && (fi.getVectorEncoding() == VectorEncoding.FLOAT32 + || fi.getVectorEncoding() == VectorEncoding.FLOAT16); FloatVectorValues floatValues = delegate.getFloatVectorValues(field); assert floatValues != null; assert floatValues.iterator().docID() == -1; @@ -164,7 +165,8 @@ public void search( FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 - && fi.getVectorEncoding() == VectorEncoding.FLOAT32; + && (fi.getVectorEncoding() == VectorEncoding.FLOAT32 + || fi.getVectorEncoding() == VectorEncoding.FLOAT16); acceptDocs = AssertingAcceptDocs.wrap(acceptDocs); delegate.search(field, target, knnCollector, acceptDocs); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index dacfeeaf2661..f82074274f87 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -129,6 +129,9 @@ protected void addRandomFields(Document doc) { case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction)); case FLOAT32 -> doc.add(new KnnFloatVectorField("v2", randomNormalizedVector(30), similarityFunction)); + case FLOAT16 -> + doc.add( + new KnnFloatVectorField("v2", randomNormalizedFloat16Vector(30), similarityFunction)); } } @@ -894,7 +897,16 @@ public void testSparseVectors() throws Exception { } case FLOAT32 -> { float[] v = randomNormalizedVector(fieldDims[field]); - doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field])); + doc.add( + new KnnFloatVectorField( + fieldName, v, fieldSimilarityFunctions[field], VectorEncoding.FLOAT32)); + fieldTotals[field] += v[0]; + } + case FLOAT16 -> { + float[] v = randomNormalizedFloat16Vector(fieldDims[field]); + doc.add( + new KnnFloatVectorField( + fieldName, v, fieldSimilarityFunctions[field], VectorEncoding.FLOAT16)); fieldTotals[field] += v[0]; } } @@ -922,7 +934,7 @@ public void testSparseVectors() throws Exception { } } } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { for (LeafReaderContext ctx : r.leaves()) { FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); if (vectorValues != null) { @@ -1697,6 +1709,30 @@ public static float[] randomVector(int dim) { return v; } + public static float[] randomFloat16Vector(int dim) { + assert dim > 0; + float[] v = new float[dim]; + double squareSum = 0.0; + // keep generating until we don't get a zero-length vector + while (squareSum == 0.0) { + squareSum = 0.0; + for (int i = 0; i < dim; i++) { + v[i] = Float.float16ToFloat(Float.floatToFloat16(random().nextFloat())); + squareSum += v[i] * v[i]; + } + } + return v; + } + + public static float[] randomNormalizedFloat16Vector(int dim) { + float[] v = randomVector(dim); + VectorUtil.l2normalize(v); + for (int i = 0; i < v.length; i++) { + v[i] = Float.float16ToFloat(Float.floatToFloat16(v[i])); + } + return v; + } + public static float[] randomNormalizedVector(int dim) { float[] v = randomVector(dim); VectorUtil.l2normalize(v); @@ -1759,7 +1795,8 @@ public void testVectorEncodingOrdinals() { // enumerators assertEquals(0, VectorEncoding.BYTE.ordinal()); assertEquals(1, VectorEncoding.FLOAT32.ordinal()); - assertEquals(2, VectorEncoding.values().length); + assertEquals(2, VectorEncoding.FLOAT16.ordinal()); + assertEquals(3, VectorEncoding.values().length); } public void testAdvance() throws Exception { @@ -1838,7 +1875,16 @@ public void testVectorValuesReportCorrectDocs() throws Exception { case FLOAT32 -> { float[] v = randomNormalizedVector(dim); fieldValuesCheckSum += v[0]; - doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction)); + doc.add( + new KnnFloatVectorField( + "knn_vector", v, similarityFunction, VectorEncoding.FLOAT32)); + } + case FLOAT16 -> { + float[] v = randomNormalizedFloat16Vector(dim); + fieldValuesCheckSum += v[0]; + doc.add( + new KnnFloatVectorField( + "knn_vector", v, similarityFunction, VectorEncoding.FLOAT16)); } } fieldDocCount++; @@ -1879,7 +1925,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { } } } - case FLOAT32 -> { + case FLOAT32, FLOAT16 -> { for (LeafReaderContext ctx : r.leaves()) { FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues("knn_vector"); if (vectorValues != null) { @@ -2182,7 +2228,7 @@ protected void assertOffHeapByteSize(LeafReader r, String fieldName) throws IOEx static int getNumVectors(KnnVectorsReader reader, FieldInfo fieldInfo) throws IOException { return switch (fieldInfo.getVectorEncoding()) { case BYTE -> reader.getByteVectorValues(fieldInfo.getName()).size(); - case FLOAT32 -> reader.getFloatVectorValues(fieldInfo.getName()).size(); + case FLOAT32, FLOAT16 -> reader.getFloatVectorValues(fieldInfo.getName()).size(); }; }