diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java new file mode 100644 index 0000000000000..8d25ab54d8ca1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.util.BitUtil; + +import java.nio.ByteOrder; +import java.nio.ShortBuffer; + +public final class BFloat16 { + + public static final int BYTES = Short.BYTES; + + public static short floatToBFloat16(float f) { + // this rounds towards 0 + // zero - zero exp, zero fraction + // denormal - zero exp, non-zero fraction + // infinity - all-1 exp, zero fraction + // NaN - all-1 exp, non-zero fraction + // the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into + // infinities + return (short) (Float.floatToIntBits(f) >>> 16); + } + + public static float bFloat16ToFloat(short bf) { + return Float.intBitsToFloat(bf << 16); + } + + public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { + assert bFloats.remaining() == floats.length; + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (float v : floats) { + bFloats.put(floatToBFloat16(v)); + } + } + + public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { + assert floats.length * 2 == bfBytes.length; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat((short) 
BitUtil.VH_LE_SHORT.get(bfBytes, i * 2)); + } + } + + public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) { + assert floats.length == bFloats.remaining(); + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat(bFloats.get()); + } + } + + private BFloat16() {} +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java index cba55f8a7e942..c200828876e85 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java @@ -26,7 +26,8 @@ /** Utility class for vector quantization calculations */ public class BQVectorUtils { - private static final float EPSILON = 1e-4f; + // NOTE: this is currently > 1e-4f due to bfloat16 + private static final float EPSILON = 1e-2f; public static double sqrtNewtonRaphson(double x, double curr, double prev) { return (curr == prev) ? 
curr : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java index 0d67281bf5606..fade2384441cf 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java @@ -11,18 +11,81 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.FlushInfo; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.MergeInfo; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; +import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; +import java.util.Set; public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat { protected DirectIOCapableFlatVectorsFormat(String name) { super(name); } + protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException; + + protected static boolean canUseDirectIO(SegmentReadState state) { + return FsDirectoryFactory.isHybridFs(state.directory); + } + @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { return fieldsReader(state, false); } - public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException; + public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException { + if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) { + // only override the context for the random-access use case + SegmentReadState directIOState = new SegmentReadState( + state.directory, + state.segmentInfo, + 
state.fieldInfos, + new DirectIOContext(state.context.hints()), + state.segmentSuffix + ); + // Use mmap for merges and direct I/O for searches. + return new MergeReaderWrapper(createReader(directIOState), createReader(state)); + } else { + return createReader(state); + } + } + + protected static class DirectIOContext implements IOContext { + + final Set hints; + + public DirectIOContext(Set hints) { + // always add DirectIOHint to the hints given + this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE)); + } + + @Override + public Context context() { + return Context.DEFAULT; + } + + @Override + public MergeInfo mergeInfo() { + return null; + } + + @Override + public FlushInfo flushInfo() { + return null; + } + + @Override + public Set hints() { + return hints; + } + + @Override + public IOContext withHints(FileOpenHint... hints) { + return new DirectIOContext(Set.of(hints)); + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java index 293cb61e9105c..eac3e708dfe66 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java @@ -109,7 +109,7 @@ public QuantizationResult[] multiScalarQuantize( } public QuantizationResult scalarQuantize(float[] vector, float[] residualDestination, int[] destination, byte bits, float[] centroid) { - assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(vector); assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); assert vector.length <= destination.length; assert bits > 0 && bits <= 8; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index 64796d10662e4..99bc9a9d7bdb2 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -58,12 +58,12 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { public static final int VERSION_DIRECT_IO = 1; public static final int VERSION_CURRENT = VERSION_DIRECT_IO; - private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); private static final Map supportedFormats = Map.of( - rawVectorFormat.getName(), - rawVectorFormat + float32VectorFormat.getName(), + float32VectorFormat ); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. @@ -79,6 +79,7 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final int vectorPerCluster; private final int centroidsPerParentCluster; private final boolean useDirectIO; + private final DirectIOCapableFlatVectorsFormat rawVectorFormat; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { this(vectorPerCluster, centroidsPerParentCluster, false); @@ -109,6 +110,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; this.useDirectIO = useDirectIO; + this.rawVectorFormat = float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index fe26de0fe869c..fca247f084adf 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -27,6 +27,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.ESVectorUtil; @@ -70,7 +71,7 @@ public RandomVectorScorer getRandomVectorScorer( assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer"; OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); float[] centroid = binarizedVectors.getCentroid(); - assert similarityFunction != COSINE || VectorUtil.isUnitVector(target); + assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(target); float[] scratch = new float[vectorValues.dimension()]; int[] initial = new int[target.length]; byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index 8d9e8bf448adc..b67f7186b8f4d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -25,23 +25,17 @@ import org.apache.lucene.search.DocAndFloatFeatureBuffer; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; -import org.apache.lucene.store.FlushInfo; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.MergeInfo; import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues; import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; -import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; -import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; import java.util.List; -import java.util.Set; public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { @@ -61,17 +55,13 @@ protected FlatVectorsScorer flatVectorsScorer() { } @Override - public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99FlatVectorsWriter(state, vectorsScorer); - } - - static boolean canUseDirectIO(SegmentReadState state) { - return FsDirectoryFactory.isHybridFs(state.directory); + protected FlatVectorsReader createReader(SegmentReadState state) throws IOException { + return new Lucene99FlatVectorsReader(state, vectorsScorer); } @Override - public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return fieldsReader(state, false); + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99FlatVectorsWriter(state, vectorsScorer); } @Override @@ -99,41 +89,6 @@ 
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI } } - static class DirectIOContext implements IOContext { - - final Set hints; - - DirectIOContext(Set hints) { - // always add DirectIOHint to the hints given - this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE)); - } - - @Override - public Context context() { - return Context.DEFAULT; - } - - @Override - public MergeInfo mergeInfo() { - return null; - } - - @Override - public FlushInfo flushInfo() { - return null; - } - - @Override - public Set hints() { - return hints; - } - - @Override - public IOContext withHints(FileOpenHint... hints) { - return new DirectIOContext(Set.of(hints)); - } - } - static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader { private final Lucene99FlatVectorsReader inner; private final SegmentReadState state; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java new file mode 100644 index 0000000000000..c6b2f61a366e9 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java @@ -0,0 +1,64 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; + +import java.io.IOException; + +public final class ES93BFloat16FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { + + static final String NAME = "ES93BFloat16FlatVectorsFormat"; + static final String META_CODEC_NAME = "ES93BFloat16FlatVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES93BFloat16FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; + + public ES93BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + super(NAME); + this.vectorsScorer = vectorsScorer; + } + + @Override + protected FlatVectorsScorer flatVectorsScorer() { + return vectorsScorer; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES93BFloat16FlatVectorsWriter(state, vectorsScorer); + } + + @Override + protected FlatVectorsReader createReader(SegmentReadState state) throws IOException { + return new ES93BFloat16FlatVectorsReader(state, vectorsScorer); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java new file mode 100644 index 0000000000000..c71470d6be15e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java @@ -0,0 +1,325 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.IOException; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +public final class ES93BFloat16FlatVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES93BFloat16FlatVectorsReader.class); + + private final IntObjectHashMap fields = new IntObjectHashMap<>(); + private final IndexInput vectorData; + private final FieldInfos fieldInfos; + private final IOContext dataContext; + + public ES93BFloat16FlatVectorsReader(SegmentReadState state, 
FlatVectorsScorer scorer) throws IOException { + super(scorer); + int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; + boolean success = false; + // Flat formats are used to randomly access vectors from their node ID that is stored + // in the HNSW graph. + dataContext = state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM); + try { + vectorData = openDataInput( + state, + versionMeta, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + dataContext + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES93BFloat16FlatVectorsFormat.META_EXTENSION + ); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES93BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES93BFloat16FlatVectorsFormat.VERSION_START, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + 
ES93BFloat16FlatVectorsFormat.VERSION_START, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = FieldEntry.create(meta, info); + fields.put(info.number, fieldEntry); + } + } + + @Override + public long ramBytesUsed() { + return ES93BFloat16FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + final FieldEntry entry = getFieldEntryOrThrow(fieldInfo.name); + return Map.of(ES93BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength()); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FlatVectorsReader getMergeInstance() throws IOException { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + vectorData.updateIOContext(dataContext.withHints(DataAccessHint.SEQUENTIAL)); + return this; + } + + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + return entry; + } + + 
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldEntry fieldEntry = getFieldEntryOrThrow(field); + if (fieldEntry.vectorEncoding != expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " + expectedEncoding + ); + } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + throw new IllegalStateException(field + " only supports float vectors"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + throw new UnsupportedOperationException(field + " only supports float vectors"); + } + + @Override + public void finishMerge() throws IOException { + // This makes sure that the access pattern hint is reverted back since HNSW implementation + // needs it + vectorData.updateIOContext(dataContext); + } + + @Override + public void close() 
throws IOException { + IOUtils.close(vectorData); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorDataOffset, + long vectorDataLength, + int dimension, + int size, + OrdToDocDISIReaderConfiguration ordToDoc, + FieldInfo info + ) { + + FieldEntry { + if (vectorEncoding == VectorEncoding.BYTE) { + throw new IllegalStateException( + "Incorrect vector encoding for field=\"" + info.name + "\"; " + vectorEncoding + " not supported" + ); + } + + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + int infoVectorDimension = info.getVectorDimension(); + if (infoVectorDimension != dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + infoVectorDimension + " != " + dimension + ); + } + + int byteSize = BFloat16.BYTES; + long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); + long numBytes = Math.multiplyExact(vectorBytes, size); + if (numBytes != vectorDataLength) { + throw new IllegalStateException( + "Vector data length " + + vectorDataLength + + " not matching size=" + + size + + " * dim=" + + dimension + + " * byteSize=" + + byteSize + + " = " + + numBytes + ); + } + } + + static FieldEntry create(IndexInput input, FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final var vectorDataOffset = input.readVLong(); + final var vectorDataLength = input.readVLong(); + final var dimension = input.readVInt(); + final var size = input.readInt(); + final var ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry(similarityFunction, 
vectorEncoding, vectorDataOffset, vectorDataLength, dimension, size, ordToDoc, info); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java new file mode 100644 index 0000000000000..3c143d94fd6b5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java @@ -0,0 +1,434 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +public final class ES93BFloat16FlatVectorsWriter extends FlatVectorsWriter { + + private static final long 
SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ES93BFloat16FlatVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final IndexOutput meta, vectorData; + + private final List<FieldWriter<?>> fields = new ArrayList<>(); + private boolean finished; + + public ES93BFloat16FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES93BFloat16FlatVectorsFormat.META_EXTENSION + ); + + String vectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION + ); + + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES93BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + vectorData, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FieldWriter newField = FieldWriter.create(fieldInfo); + fields.add(newField); + return newField; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + for (FieldWriter field : fields) { + if (sortMap == null) { + writeField(field, maxDoc); + } else { + writeSortingField(field, maxDoc, sortMap); + } + field.finish(); + } + } + + @Override + public void finish() throws IOException { + 
if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + switch (fieldData.fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeBFloat16Vectors(fieldData); + case BYTE -> throw new IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + } + + private void writeBFloat16Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + BFloat16.floatToBFloat16((float[]) v, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { + final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = switch (fieldData.fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap); + case BYTE -> throw new 
IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + } + + private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + BFloat16.floatToBFloat16(vector, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + return vectorDataOffset; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Since we know we will not be searching for additional indexing, we can just write the + // vectors directly to the new segment.
+ long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + // No need to use temporary file as we don't have to re-open for reading + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeVectorData(vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + IndexOutput tempVectorData = segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context); + IndexInput vectorDataInput = null; + boolean success = false; + try { + // write the vector data to a temporary file + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeVectorData(tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new UnsupportedOperationException("ES93BFloat16FlatVectorsWriter only supports float vectors"); + }; + CodecUtil.writeFooter(tempVectorData); + IOUtils.close(tempVectorData); + + // This temp file will be accessed in a random-access fashion to construct the HNSW graph. + // Note: don't use the context from the state, which is a flush/merge context, not expecting + // to perform random reads. The file holds bfloat16 values, so it is read back via OffHeapBFloat16VectorValues.
+ vectorDataInput = segmentWriteState.directory.openInput( + tempVectorData.getName(), + IOContext.DEFAULT.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + // copy the temporary file vectors to the actual data file + vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(vectorDataInput); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + success = true; + final IndexInput finalVectorDataInput = vectorDataInput; + final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * BFloat16.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + return new FlatCloseableRandomVectorScorerSupplier(() -> { + IOUtils.close(finalVectorDataInput); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + }, docsWithField.cardinality(), randomVectorScorerSupplier); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); + try { + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + } catch (Exception e) { + // ignore + } + } + } + } + + private void writeMeta(FieldInfo field, int maxDoc, long vectorDataOffset, long vectorDataLength, DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + + // write docIDs + int
count = docsWithField.cardinality(); + meta.writeInt(count); + OrdToDocDISIReaderConfiguration.writeStoredMeta(DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(iter.index()); + BFloat16.floatToBFloat16(value, buffer.asShortBuffer()); + output.writeBytes(buffer.array(), buffer.limit()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData); + } + + private abstract static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final int dim; + private final DocsWithFieldSet docsWithField; + private final List vectors; + private boolean finished; + + private int lastDocID = -1; + + static FieldWriter create(FieldInfo fieldInfo) { + int dim = fieldInfo.getVectorDimension(); + return switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> new ES93BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { + @Override + public float[] copyValue(float[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); + }; + } + + 
FieldWriter(FieldInfo fieldInfo) { + super(); + this.fieldInfo = fieldInfo; + this.dim = fieldInfo.getVectorDimension(); + this.docsWithField = new DocsWithFieldSet(); + vectors = new ArrayList<>(); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)" + ); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (vectors.size() == 0) return size; + + int byteSize = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 + ? BFloat16.BYTES /* NOTE(review): buffered vectors are float[] copies, so 2 bytes per value may under-count heap use; confirm */ + : fieldInfo.getVectorEncoding().byteSize; + + return size + docsWithField.ramBytesUsed() + (long) vectors.size() * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) vectors.size() * fieldInfo.getVectorDimension() * byteSize; + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + } + + static final class FlatCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + FlatCloseableRandomVectorScorerSupplier(Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors =
numVectors; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 69f4f96a4e829..2535784bd1004 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. * - * Modifications copyright (C) 2024 Elasticsearch B.V. + * Modifications copyright (C) 2025 Elasticsearch B.V. */ package org.elasticsearch.index.codec.vectors.es93; @@ -25,14 +25,13 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; import java.io.IOException; -import java.util.Map; /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 @@ -87,33 +86,23 @@ *
  • The sparse vector information, if required, mapping vector ordinal to doc ID * */ -public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsFormat { +public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat { public static final String NAME = "ES93BinaryQuantizedVectorsFormat"; - private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); - - private static final Map supportedFormats = Map.of( - rawVectorFormat.getName(), - rawVectorFormat - ); - private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); - private final boolean useDirectIO; + private final ES93GenericFlatVectorsFormat rawFormat; public ES93BinaryQuantizedVectorsFormat() { - super(NAME); - this.useDirectIO = false; + this(false, false); } - public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO) { + public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) { super(NAME); - this.useDirectIO = useDirectIO; + rawFormat = new ES93GenericFlatVectorsFormat(useBFloat16, useDirectIO); } @Override @@ -122,27 +111,17 @@ protected FlatVectorsScorer flatVectorsScorer() { } @Override - protected boolean useDirectIOReads() { - return useDirectIO; - } - - @Override - protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() { - return rawVectorFormat; - } - - @Override - protected Map supportedReadFlatVectorsFormats() { - return supportedFormats; + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES818BinaryQuantizedVectorsWriter(scorer, rawFormat.fieldsWriter(state), state); } @Override - public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state); + public 
FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES818BinaryQuantizedVectorsReader(state, rawFormat.fieldsReader(state), scorer); } @Override - public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new ES818BinaryQuantizedVectorsReader(state, super.fieldsReader(state), scorer); + public String toString() { + return getName() + "(name=" + getName() + ", rawVectorFormat=" + rawFormat + ", scorer=" + scorer + ")"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 526a4241ed89e..e2026e24506e7 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -9,7 +9,9 @@ package org.elasticsearch.index.codec.vectors.es93; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -19,8 +21,9 @@ import java.io.IOException; import java.util.Map; -public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { +public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { + static final String NAME = "ES93GenericFlatVectorsFormat"; static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta"; @@ -34,28 +37,46 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo VERSION_CURRENT ); - public ES93GenericFlatVectorsFormat(String name) { - super(name); - } + private static final 
FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); + + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(scorer); + // TODO: a separate scorer for bfloat16 + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(scorer); - protected abstract DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat(); + private static final Map supportedFormats = Map.of( + float32VectorFormat.getName(), + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat + ); - protected abstract boolean useDirectIOReads(); + private final DirectIOCapableFlatVectorsFormat writeFormat; + private final boolean useDirectIO; + + public ES93GenericFlatVectorsFormat() { + this(false, false); + } - protected abstract Map supportedReadFlatVectorsFormats(); + public ES93GenericFlatVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + super(NAME); + writeFormat = useBFloat16 ? 
bfloat16VectorFormat : float32VectorFormat; + this.useDirectIO = useDirectIO; + } + + @Override + protected FlatVectorsScorer flatVectorsScorer() { + return scorer; + } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - var flatFormat = writeFlatVectorsFormat(); - boolean directIO = useDirectIOReads(); - return new ES93GenericFlatVectorsWriter(META, flatFormat.getName(), directIO, state, flatFormat.fieldsWriter(state)); + return new ES93GenericFlatVectorsWriter(META, writeFormat.getName(), useDirectIO, state, writeFormat.fieldsWriter(state)); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - var readFormats = supportedReadFlatVectorsFormats(); return new ES93GenericFlatVectorsReader(META, state, (f, dio) -> { - var format = readFormats.get(f); + var format = supportedFormats.get(f); if (format == null) return null; return format.fieldsReader(state, dio); }); @@ -63,13 +84,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException @Override public String toString() { - return getName() - + "(name=" - + getName() - + ", writeFlatVectorFormat=" - + writeFlatVectorsFormat() - + ", readFlatVectorsFormats=" - + supportedReadFlatVectorsFormats().values() - + ")"; + return getName() + "(name=" + getName() + ", format=" + writeFormat + ")"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index b1ade1524e250..c42701f1e5d6f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
* - * Modifications copyright (C) 2024 Elasticsearch B.V. + * Modifications copyright (C) 2025 Elasticsearch B.V. */ package org.elasticsearch.index.codec.vectors.es93; @@ -47,13 +47,13 @@ public ES93HnswBinaryQuantizedVectorsFormat() { /** * Constructs a format using the given graph construction parameters. * - * @param maxConn the maximum number of connections to a node in the HNSW graph - * @param beamWidth the size of the queue maintained during graph construction. + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. * @param useDirectIO whether to use direct IO when reading raw vectors */ - public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO) { + public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) { super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } /** @@ -71,11 +71,12 @@ public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, boolean useDirectIO, + boolean useBFloat16, int numMergeWorkers, ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java new file mode 100644 index 0000000000000..42f02d2d21366 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -0,0 +1,306 @@ +/* + * @notice + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.IOException; + +abstract class OffHeapBFloat16VectorValues extends FloatVectorValues { + + protected final int dimension; + protected final int size; + protected final IndexInput slice; + protected final int byteSize; + protected int lastOrd = -1; + protected final byte[] bfloatBytes; + protected final float[] value; + protected final VectorSimilarityFunction similarityFunction; + protected 
final FlatVectorsScorer flatVectorsScorer; + + OffHeapBFloat16VectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + this.dimension = dimension; + this.size = size; + this.slice = slice; + this.byteSize = byteSize; + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; + bfloatBytes = new byte[dimension * BFloat16.BYTES]; + value = new float[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return value; + } + slice.seek((long) targetOrd * byteSize); + // no readShorts() method + slice.readBytes(bfloatBytes, 0, bfloatBytes.length); + BFloat16.bFloat16ToFloat(bfloatBytes, value); + lastOrd = targetOrd; + return value; + } + + static OffHeapBFloat16VectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, + OrdToDocDISIReaderConfiguration configuration, + VectorEncoding vectorEncoding, + int dimension, + int size, + long vectorDataOffset, + long vectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty() || vectorEncoding != VectorEncoding.FLOAT32) { + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); + } + IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); + int byteSize = dimension * BFloat16.BYTES; + if (configuration.isDense()) { + return new DenseOffHeapVectorValues(dimension, size, bytesSlice, byteSize, flatVectorsScorer, vectorSimilarityFunction); + } else { + return new SparseOffHeapVectorValues( + configuration, + vectorData, + bytesSlice, + dimension, + size, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction + ); + } + } + + /** 
+ * Dense vector values that are stored off-heap. This is the most common case when every doc has a + * vector. + */ + static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public int ordToDoc(int ord) { + return ord; + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.docID()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class SparseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn is kept so that copy() can build a fresh IndexedDISI and ord-to-doc reader from it + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + IndexInput dataIn, + IndexInput slice, + int dimension, + int size, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction
similarityFunction + ) throws IOException { + + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dataIn, + slice.clone(), + dimension, + size, + byteSize, + flatVectorsScorer, + similarityFunction + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); + } + + @Override + public int dimension() { + return super.dimension(); + } + + @Override + public int size() { + return 0; + } + + @Override + public 
EmptyOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] vectorValue(int targetOrd) { + throw new UnsupportedOperationException(); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index f3cd4f92a6a87..29e6c59d995be 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -60,7 +60,8 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase LogConfigurator.loadLog4jPlugins(); LogConfigurator.configureESLogging(); // native access requires logging to be initialized } - KnnVectorsFormat format; + + private KnnVectorsFormat format; @Before @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..9ae394733631e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.index.VectorEncoding; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES93BinaryQuantizedBFloat16VectorsFormatTests extends ES93BinaryQuantizedVectorsFormatTests { + @Override + boolean useBFloat16() { + return true; + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = 
expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.E-]+)> but was:<([0-9.E-]+)>"); + + private static void assertFloatsWithinBounds(AssertionError error) { + Matcher m = FLOAT_ASSERTION_FAILURE.matcher(String.valueOf(error.getMessage())); + if (m.matches() == false) { + throw error; // not a float comparison failure (or no message at all), just rethrow the original + } + + // numbers (including E-notation such as 9.7E-4) just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = Math.abs(expected) * 0.01; // within 1% of |expected|; closeTo needs a non-negative delta + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 96538fd8dfb74..108739533a76b 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -56,17 +56,15 @@ import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Locale; -import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; -import static org.hamcrest.Matchers.either; -import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.oneOf; public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -77,9 +75,13 @@ public class
ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT private KnnVectorsFormat format; + boolean useBFloat16() { + return false; + } + @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(useBFloat16(), random().nextBoolean()); super.setUp(); } @@ -191,11 +193,12 @@ public KnnVectorsFormat knnVectorsFormat() { } }; String expectedPattern = "ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," - + " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat," - + " flatVectorScorer=%s())"; - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer))); + + " rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}()))," + + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}()))"; + var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer"); + var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer)); } @Override @@ -239,7 +242,8 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b assertEquals(expectVecOffHeap ? 2 : 1, offHeap.size()); assertTrue(offHeap.get("veb") > 0L); if (expectVecOffHeap) { - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + int bytes = useBFloat16() ? 
BFloat16.BYTES : Float.BYTES; + assertEquals(vector.length * bytes, (long) offHeap.get("vec")); } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..c6f3e555013b3 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.index.VectorEncoding; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES93HnswBinaryQuantizedBFloat16VectorsFormatTests extends ES93HnswBinaryQuantizedVectorsFormatTests { + + @Override + boolean useBFloat16() { + return true; + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testSingleVectorCase() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSingleVectorCase); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void 
testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.E-]+)> but was:<([0-9.E-]+)>"); + + private static void assertFloatsWithinBounds(AssertionError error) { + Matcher m = FLOAT_ASSERTION_FAILURE.matcher(String.valueOf(error.getMessage())); + if (m.matches() == false) { + throw error; // not a float comparison failure (or no message at all), just rethrow the original + } + + // numbers (including E-notation such as 9.7E-4) just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = Math.abs(expected) * 0.01; // within 1% of |expected|; closeTo needs a non-negative delta + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 809436f139573..45e489662f3bf 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -46,18 +46,17 @@ import org.apache.lucene.util.SameThreadExecutorService; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.IOException; import java.util.Arrays; -import java.util.Locale; import static java.lang.String.format; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static
org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.hamcrest.Matchers.either; -import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.oneOf; public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -68,9 +67,13 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFor private KnnVectorsFormat format; + boolean useBFloat16() { + return false; + } + @Override public void setUp() throws Exception { - format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, random().nextBoolean()); + format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); super.setUp(); } @@ -83,17 +86,19 @@ public void testToString() { FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, 1, null); + return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, false, 1, null); } }; - String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat," + + " maxConn=10, beamWidth=20," + " flatVectorFormat=ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," - + " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat," - + " flatVectorScorer=%s())"; + + " rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}()))," + + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}())))"; - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, 
expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer))); + var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer"); + var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer)); } public void testSingleVectorCase() throws Exception { @@ -137,15 +142,15 @@ public void testSingleVectorCase() throws Exception { } public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false, 
false)); expectThrows( IllegalArgumentException.class, - () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, 1, new SameThreadExecutorService()) + () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) ); } @@ -189,7 +194,8 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b assertEquals(1L, (long) offHeap.get("vex")); assertTrue(offHeap.get("veb") > 0L); if (expectVecOffHeap) { - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + assertEquals(vector.length * bytes, (long) offHeap.get("vec")); } } }