diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 6ac0ce7ba99a8..d615a6f3a7895 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -486,4 +486,5 @@ exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn; exports org.elasticsearch.inference.telemetry; + exports org.elasticsearch.index.codec.vectors.es91 to org.elasticsearch.test.knn; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java index 325188624a2f4..4713349355292 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java @@ -17,7 +17,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -29,6 +28,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat; import java.io.IOException; import java.util.Map; @@ -39,7 +39,7 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat { static final String NAME = "ES813FlatVectorFormat"; - private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE); + private static final FlatVectorsFormat format = new 
ES91BFloat16FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE); /** * Sole constructor diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java index 56710d49b5a7a..f4dded095c606 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java @@ -16,7 +16,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; -import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter; import org.apache.lucene.index.ByteVectorValues; @@ -34,6 +33,7 @@ import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedVectorsReader; import org.apache.lucene.util.quantization.ScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat; import org.elasticsearch.simdvec.VectorScorerFactory; import org.elasticsearch.simdvec.VectorSimilarityType; @@ -48,7 +48,7 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat { static final String NAME = "ES814ScalarQuantizedVectorsFormat"; private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4); - private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE); + private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE); static final FlatVectorsScorer flatVectorScorer = new ESFlatVectorsScorer( new 
ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index e3242ee411e7d..5739e44533119 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -13,7 +13,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; @@ -24,6 +23,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat; import java.io.IOException; @@ -31,7 +31,7 @@ class ES815BitFlatVectorsFormat extends FlatVectorsFormat { - private static final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE); + private static final FlatVectorsFormat delegate = new ES91BFloat16FlatVectorsFormat(FlatBitVectorScorer.INSTANCE); protected ES815BitFlatVectorsFormat() { super("ES815BitFlatVectorsFormat"); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java index 61b6edc474d1f..e34c0576d7cb7 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java @@ -23,9 +23,9 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat; import java.io.IOException; @@ -47,7 +47,7 @@ public class ES816BinaryQuantizedVectorsFormat extends FlatVectorsFormat { static final String VECTOR_DATA_EXTENSION = "veb"; static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; - private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java index 9a31ff42a7c5c..d425112a16129 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java @@ -23,11 +23,11 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import 
org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat; import java.io.IOException; @@ -110,7 +110,7 @@ private static boolean getUseDirectIO() { private static final FlatVectorsFormat rawVectorFormat = USE_DIRECT_IO ? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()) - : new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + : new ES91BFloat16FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/BFloat16.java new file mode 100644 index 0000000000000..729173c7a57f3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/BFloat16.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es91; + +import org.apache.lucene.util.BitUtil; + +class BFloat16 { + + public static final int BYTES = Short.BYTES; + + public static short floatToBFloat16(float f) { + // this does round towards 0 + // zero - zero exp, zero fraction + // denormal - zero exp, non-zero fraction + // infinity - all-1 exp, zero fraction + // NaN - all-1 exp, non-zero fraction + // the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into infinities + return (short) (Float.floatToIntBits(f) >>> 16); + } + + public static float bFloat16ToFloat(short bf) { + return Float.intBitsToFloat(bf << 16); + } + + public static short[] floatToBFloat16(float[] f) { + short[] bf = new short[f.length]; + for (int i = 0; i < f.length; i++) { + bf[i] = floatToBFloat16(f[i]); + } + return bf; + } + + public static float[] bFloat16ToFloat(short[] bf) { + float[] f = new float[bf.length]; + for (int i = 0; i < bf.length; i++) { + f[i] = bFloat16ToFloat(bf[i]); + } + return f; + } + + public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { + assert floats.length * 2 == bfBytes.length; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2)); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsFormat.java new file mode 100644 index 0000000000000..4eefbd9783616 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsFormat.java @@ -0,0 +1,65 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es91; + +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +public class ES91BFloat16FlatVectorsFormat extends FlatVectorsFormat { + + static final String NAME = "ES91BFloat16FlatVectorsFormat"; + static final String META_CODEC_NAME = "ES91BFloat16FlatVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES91BFloat16FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; + + /** Constructs a format */ + public ES91BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + super(NAME); + this.vectorsScorer = vectorsScorer; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES91BFloat16FlatVectorsWriter(state, 
vectorsScorer); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES91BFloat16FlatVectorsReader(state, vectorsScorer); + } + + @Override + public String toString() { + return "Lucene99FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')'; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsReader.java new file mode 100644 index 0000000000000..056e1ad3f1b4f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsReader.java @@ -0,0 +1,344 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es91; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.core.IOUtils; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +public final class ES91BFloat16FlatVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES91BFloat16FlatVectorsFormat.class); + + private final IntObjectHashMap fields = new IntObjectHashMap<>(); + private final IndexInput 
vectorData; + private final FieldInfos fieldInfos; + + public ES91BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; + try { + vectorData = openDataInput( + state, + versionMeta, + ES91BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, + ES91BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Flat formats are used to randomly access vectors from their node ID that is stored + // in the HNSW graph. + state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + } catch (Throwable t) { + IOUtils.closeWhileHandlingException(this); + throw t; + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES91BFloat16FlatVectorsFormat.META_EXTENSION + ); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES91BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES91BFloat16FlatVectorsFormat.VERSION_START, + ES91BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + 
ES91BFloat16FlatVectorsFormat.VERSION_START, + ES91BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + return in; + } catch (Throwable t) { + IOUtils.closeWhileHandlingException(in); + throw t; + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = FieldEntry.create(meta, info); + fields.put(info.number, fieldEntry); + } + } + + @Override + public long ramBytesUsed() { + return ES91BFloat16FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + final FieldEntry entry = getFieldEntryOrThrow(fieldInfo.name); + return Map.of(ES91BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength()); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FlatVectorsReader getMergeInstance() { + try { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + this.vectorData.updateReadAdvice(ReadAdvice.SEQUENTIAL); + return this; + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); 
+ } + return entry; + } + + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldEntry fieldEntry = getFieldEntryOrThrow(field); + if (fieldEntry.vectorEncoding != expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " + expectedEncoding + ); + } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + return 
vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public void finishMerge() throws IOException { + // This makes sure that the access pattern hint is reverted back since HNSW implementation + // needs it + this.vectorData.updateReadAdvice(ReadAdvice.RANDOM); + } + + @Override + public void close() throws IOException { + IOUtils.close(vectorData); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorDataOffset, + long vectorDataLength, + int dimension, + int size, + OrdToDocDISIReaderConfiguration ordToDoc, + FieldInfo info + ) { + + FieldEntry { + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + int infoVectorDimension = info.getVectorDimension(); + if (infoVectorDimension != dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + infoVectorDimension + " != " + dimension + ); + } + + int byteSize = switch (info.getVectorEncoding()) { + case BYTE -> Byte.BYTES; + case FLOAT32 -> BFloat16.BYTES; + }; + long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); + long numBytes = Math.multiplyExact(vectorBytes, size); + if (numBytes != vectorDataLength) { + throw new IllegalStateException( + "Vector data length " + + vectorDataLength + + " not matching size=" + + size + + " * dim=" + + dimension + + " * byteSize=" + + byteSize + + " = " + + numBytes + ); + } + } + + static FieldEntry create(IndexInput input, FieldInfo 
info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final var vectorDataOffset = input.readVLong(); + final var vectorDataLength = input.readVLong(); + final var dimension = input.readVInt(); + final var size = input.readInt(); + final var ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry(similarityFunction, vectorEncoding, vectorDataOffset, vectorDataLength, dimension, size, ordToDoc, info); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsWriter.java new file mode 100644 index 0000000000000..e91c3f69be320 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsWriter.java @@ -0,0 +1,479 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es91; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.core.IOUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +public final class ES91BFloat16FlatVectorsWriter extends FlatVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = 
/**
 * Creates the segment's metadata and vector data outputs and writes an index header to each.
 * If anything fails, this writer is closed before the failure is rethrown.
 *
 * @param state  the segment being written
 * @param scorer the scorer passed on to the parent writer
 * @throws IOException if an output cannot be created or a header cannot be written
 */
public ES91BFloat16FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) throws IOException {
    super(scorer);
    this.segmentWriteState = state;

    final String segmentName = state.segmentInfo.name;
    final String suffix = state.segmentSuffix;
    final String metaName = IndexFileNames.segmentFileName(segmentName, suffix, ES91BFloat16FlatVectorsFormat.META_EXTENSION);
    final String dataName = IndexFileNames.segmentFileName(segmentName, suffix, ES91BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION);

    try {
        // the metadata output is created first, then the vector data output
        meta = state.directory.createOutput(metaName, state.context);
        vectorData = state.directory.createOutput(dataName, state.context);

        final int version = ES91BFloat16FlatVectorsFormat.VERSION_CURRENT;
        CodecUtil.writeIndexHeader(meta, ES91BFloat16FlatVectorsFormat.META_CODEC_NAME, version, state.segmentInfo.getId(), suffix);
        CodecUtil.writeIndexHeader(
            vectorData,
            ES91BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
            version,
            state.segmentInfo.getId(),
            suffix
        );
    } catch (Throwable t) {
        IOUtils.closeWhileHandlingException(this);
        throw t;
    }
}
finished = true; + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + switch (fieldData.fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectors(fieldData); + case FLOAT32 -> writeBFloat16Vectors(fieldData); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + } + + private void writeBFloat16Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + short[] data = BFloat16.floatToBFloat16((float[]) v); + buffer.asShortBuffer().put(data); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + } + + private void writeByteVectors(FieldWriter fieldData) throws IOException { + for (Object v : fieldData.vectors) { + byte[] vector = (byte[]) v; + vectorData.writeBytes(vector, vector.length); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { + final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = switch (fieldData.fieldInfo.getVectorEncoding()) { + case BYTE -> 
writeSortedByteVectors(fieldData, ordMap); + case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + } + + private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + short[] data = BFloat16.floatToBFloat16(vector); + buffer.asShortBuffer().put(data); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + return vectorDataOffset; + } + + private long writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + for (int ordinal : ordMap) { + byte[] vector = (byte[]) fieldData.vectors.get(ordinal); + vectorData.writeBytes(vector, vector.length); + } + return vectorDataOffset; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. 
+ long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + // No need to use temporary file as we don't have to re-open for reading + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectorData(vectorData, KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); + case FLOAT32 -> writeVectorData(vectorData, KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + IndexOutput tempVectorData = segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context); + IndexInput vectorDataInput = null; + try { + // write the vector data to a temporary file + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectorData( + tempVectorData, + KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState) + ); + case FLOAT32 -> writeVectorData( + tempVectorData, + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState) + ); + }; + CodecUtil.writeFooter(tempVectorData); + IOUtils.close(tempVectorData); + + // This temp file will be accessed in a random-access fashion to construct the HNSW graph. + // Note: don't use the context from the state, which is a flush/merge context, not expecting + // to perform random reads. 
+ vectorDataInput = segmentWriteState.directory.openInput( + tempVectorData.getName(), + IOContext.DEFAULT.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + // copy the temporary file vectors to the actual data file + vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(vectorDataInput); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + + final IndexInput finalVectorDataInput = vectorDataInput; + vectorDataInput = null; + + final RandomVectorScorerSupplier randomVectorScorerSupplier = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapByteVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * Byte.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * BFloat16.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + }; + return new FlatCloseableRandomVectorScorerSupplier(() -> { + IOUtils.close(finalVectorDataInput); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + }, docsWithField.cardinality(), randomVectorScorerSupplier); + } catch (Throwable t) { + IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + throw t; + } + } + + private void writeMeta(FieldInfo field, int maxDoc, long 
vectorDataOffset, long vectorDataLength, DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + + // write docIDs + int count = docsWithField.cardinality(); + meta.writeInt(count); + OrdToDocDISIReaderConfiguration.writeStoredMeta(16, meta, vectorData, count, maxDoc, docsWithField); + } + + /** + * Writes the byte vector values to the output and returns a set of documents that contains + * vectors. + */ + private static DocsWithFieldSet writeByteVectorData(IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); + assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; + output.writeBytes(binaryValue, binaryValue.length); + docsWithField.add(docV); + } + return docsWithField; + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. 
+ */ + private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(iter.index()); + short[] data = BFloat16.floatToBFloat16(value); + buffer.asShortBuffer().put(data); + output.writeBytes(buffer.array(), buffer.limit()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData); + } + + private abstract static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final int dim; + private final DocsWithFieldSet docsWithField; + private final List vectors; + private boolean finished; + + private int lastDocID = -1; + + static FieldWriter create(FieldInfo fieldInfo) { + int dim = fieldInfo.getVectorDimension(); + return switch (fieldInfo.getVectorEncoding()) { + case BYTE -> new FieldWriter(fieldInfo) { + @Override + public byte[] copyValue(byte[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + case FLOAT32 -> new FieldWriter(fieldInfo) { + @Override + public float[] copyValue(float[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + }; + } + + FieldWriter(FieldInfo fieldInfo) { + super(); + this.fieldInfo = fieldInfo; + this.dim = fieldInfo.getVectorDimension(); + this.docsWithField = new DocsWithFieldSet(); + vectors = new ArrayList<>(); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { 
+ if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)" + ); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (vectors.size() == 0) return size; + return size + docsWithField.ramBytesUsed() + (long) vectors.size() * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo + .getVectorEncoding().byteSize; + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + } + + static final class FlatCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + FlatCloseableRandomVectorScorerSupplier(Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors = numVectors; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } 
+} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/OffHeapBFloat16VectorValues.java new file mode 100644 index 0000000000000..2d61e9cd66864 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es91/OffHeapBFloat16VectorValues.java @@ -0,0 +1,311 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es91; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; + +import java.io.IOException; + +public abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements HasIndexSlice { + + protected final int dimension; + protected final int size; + protected final IndexInput slice; + protected final int byteSize; + protected int lastOrd = -1; + protected final byte[] bfloatBytes; + protected final float[] value; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; + + OffHeapBFloat16VectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + this.dimension = dimension; + this.size = size; + this.slice = slice; + this.byteSize = byteSize; + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; + bfloatBytes = new byte[dimension * 2]; + value = new float[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { 
+ return value; + } + slice.seek((long) targetOrd * byteSize); + // no readShorts() method + slice.readBytes(bfloatBytes, 0, bfloatBytes.length); + BFloat16.bFloat16ToFloat(bfloatBytes, value); + lastOrd = targetOrd; + return value; + } + + public static OffHeapBFloat16VectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, + OrdToDocDISIReaderConfiguration configuration, + VectorEncoding vectorEncoding, + int dimension, + int size, + long vectorDataOffset, + long vectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty() || vectorEncoding != VectorEncoding.FLOAT32) { + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); + } + IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); + int byteSize = dimension * BFloat16.BYTES; + if (configuration.isDense()) { + return new DenseOffHeapVectorValues(dimension, size, bytesSlice, byteSize, flatVectorsScorer, vectorSimilarityFunction); + } else { + return new SparseOffHeapVectorValues( + configuration, + vectorData, + bytesSlice, + dimension, + size, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction + ); + } + } + + /** + * Dense vector values that are stored off-heap. This is the most common case when every doc has a + * vector. 
+ */ + public static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public int ordToDoc(int ord) { + return ord; + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.docID()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class SparseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + public SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + IndexInput dataIn, + IndexInput slice, + int dimension, + int size, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) throws IOException { + + super(dimension, size, slice, byteSize, 
flatVectorsScorer, similarityFunction); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dataIn, + slice.clone(), + dimension, + size, + byteSize, + flatVectorsScorer, + similarityFunction + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + public EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); + } + + @Override + public int dimension() { + return super.dimension(); + } + + @Override + public int size() { + return 0; + } + + @Override + public EmptyOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + 
@Override + public float[] vectorValue(int targetOrd) { + throw new UnsupportedOperationException(); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java index c54903a94b54f..f6dec988be05d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java @@ -22,8 +22,8 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat; import java.io.IOException; @@ -32,7 +32,7 @@ */ public class ES816BinaryQuantizedRWVectorsFormat extends ES816BinaryQuantizedVectorsFormat { - private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() );