From f105dc88df58f3f63b3b73b46635dc3d3e36d1b7 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 7 Aug 2025 12:02:49 +0100 Subject: [PATCH 01/17] Add ES92 bfloat16 vector format --- server/src/main/java/module-info.java | 2 + .../es818/BinarizedByteVectorValues.java | 2 +- .../es818/ES818BinaryFlatVectorsScorer.java | 4 +- .../ES818BinaryQuantizedVectorsReader.java | 6 +- .../ES818BinaryQuantizedVectorsWriter.java | 20 +- .../vectors/es818/MergeReaderWrapper.java | 4 +- .../es818/OffHeapBinarizedVectorValues.java | 8 +- .../index/codec/vectors/es92/BFloat16.java | 50 ++ .../es92/ES92BFloat16FlatVectorsFormat.java | 129 ++++ .../es92/ES92BFloat16FlatVectorsReader.java | 336 ++++++++ .../es92/ES92BFloat16FlatVectorsWriter.java | 478 ++++++++++++ ...2BinaryQuantizedBFloat16VectorsFormat.java | 133 ++++ ...2BinaryQuantizedBFloat16VectorsReader.java | 391 ++++++++++ ...2BinaryQuantizedBFloat16VectorsWriter.java | 726 ++++++++++++++++++ ...wBinaryQuantizedBFloat16VectorsFormat.java | 145 ++++ .../es92/OffHeapBFloat16VectorValues.java | 311 ++++++++ .../vectors/DenseVectorFieldMapper.java | 51 +- .../org.apache.lucene.codecs.KnnVectorsFormat | 2 + ...ryQuantizedBFloat16VectorsFormatTests.java | 349 +++++++++ ...ryQuantizedBFloat16VectorsFormatTests.java | 252 ++++++ .../vectors/DenseVectorFieldTypeTests.java | 13 +- .../vectors/RescoreKnnVectorQueryTests.java | 4 + .../mapper/SemanticTextFieldMapper.java | 2 +- .../mapper/SemanticTextFieldMapperTests.java | 2 +- 24 files changed, 3388 insertions(+), 32 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java create 
mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 6ac0ce7ba99a8..d0ac9a17b0779 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -459,6 +459,8 @@ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat, + org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat, org.elasticsearch.index.codec.vectors.IVFVectorsFormat; provides org.apache.lucene.codecs.Codec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java index ca80ba52e2c2b..53867d9d1e494 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java @@ -30,7 +30,7 @@ /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ -abstract class BinarizedByteVectorValues extends ByteVectorValues { +public abstract class BinarizedByteVectorValues extends ByteVectorValues { /** * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index efb098373489f..bbc3d328d7767 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -108,7 +108,7 @@ public RandomVectorScorer getRandomVectorScorer( return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } - RandomVectorScorerSupplier getRandomVectorScorerSupplier( + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, BinarizedByteVectorValues targetVectors @@ -122,7 +122,7 @@ public String toString() { } /** Vector scorer supplier over binarized vector values */ - static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + public static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; private final BinarizedByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java index e59b9ba7e6535..0a59a69416182 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -388,11 +388,11 @@ static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, Vector } /** Binarized vector values holding row and quantized vector values */ - protected static final class BinarizedVectorValues extends FloatVectorValues { + public static final class BinarizedVectorValues extends FloatVectorValues { private final FloatVectorValues rawVectorValues; private final BinarizedByteVectorValues quantizedVectorValues; - BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + public BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { this.rawVectorValues = rawVectorValues; this.quantizedVectorValues = quantizedVectorValues; } @@ -437,7 +437,7 @@ public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); } - BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { + public BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { return quantizedVectorValues; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index 22520567f2954..12cf4742b6b47 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -723,7 +723,7 @@ public long ramBytesUsed() { } // When accessing vectorValue method, targerOrd here means a row ordinal. - static class OffHeapBinarizedQueryVectorValues { + public static class OffHeapBinarizedQueryVectorValues { private final IndexInput slice; private final int dimension; private final int size; @@ -734,7 +734,7 @@ static class OffHeapBinarizedQueryVectorValues { private int lastOrd = -1; private int quantizedComponentSum; - OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { + public OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { this.slice = data; this.dimension = dimension; this.size = size; @@ -798,7 +798,7 @@ public byte[] vectorValue(int targetOrd) throws IOException { } } - static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + public static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { private OptimizedScalarQuantizer.QuantizationResult corrections; private final byte[] binarized; private final int[] initQuantized; @@ -808,7 +808,7 @@ static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { private int lastOrd = -1; - BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + public BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { this.values = delegate; this.quantizer = quantizer; this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; @@ -881,12 +881,16 @@ public int ordToDoc(int ord) { } } - static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + public static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { private final RandomVectorScorerSupplier supplier; private 
final KnnVectorValues vectorValues; private final Closeable onClose; - BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + public BinarizedCloseableRandomVectorScorerSupplier( + RandomVectorScorerSupplier supplier, + KnnVectorValues vectorValues, + Closeable onClose + ) { this.supplier = supplier; this.onClose = onClose; this.vectorValues = vectorValues; @@ -913,11 +917,11 @@ public int totalVectorCount() { } } - static final class NormalizedFloatVectorValues extends FloatVectorValues { + public static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; - NormalizedFloatVectorValues(FloatVectorValues values) { + public NormalizedFloatVectorValues(FloatVectorValues values) { this.values = values; this.normalizedVector = new float[values.dimension()]; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java index e26784ecfdd96..915f3b4ced18e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java @@ -23,12 +23,12 @@ import java.util.Collection; import java.util.Map; -class MergeReaderWrapper extends FlatVectorsReader { +public class MergeReaderWrapper extends FlatVectorsReader { private final FlatVectorsReader mainReader; private final FlatVectorsReader mergeReader; - protected MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) { + public MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) { super(mainReader.getFlatVectorScorer()); this.mainReader = mainReader; this.mergeReader = mergeReader; diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java index 0357468c6864d..8d1474b7ecf31 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java @@ -36,7 +36,7 @@ import java.nio.ByteBuffer; /** Binarized vector values loaded from off-heap */ -abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { +public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { final int dimension; final int size; @@ -151,7 +151,7 @@ public int getVectorByteLength() { return numBytes; } - static OffHeapBinarizedVectorValues load( + public static OffHeapBinarizedVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, int size, @@ -197,8 +197,8 @@ static OffHeapBinarizedVectorValues load( } /** Dense off-heap binarized vector values */ - static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { - DenseOffHeapVectorValues( + public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + public DenseOffHeapVectorValues( int dimension, int size, float[] centroid, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java new file mode 100644 index 0000000000000..639ea6e8eed5a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.util.BitUtil; + +import java.nio.ByteOrder; +import java.nio.ShortBuffer; + +class BFloat16 { + + public static final int BYTES = Short.BYTES; + + public static short floatToBFloat16(float f) { + // this rounds towards 0 + // zero - zero exp, zero fraction + // denormal - zero exp, non-zero fraction + // infinity - all-1 exp, zero fraction + // NaN - all-1 exp, non-zero fraction + // the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into + // infinities + return (short) (Float.floatToIntBits(f) >>> 16); + } + + public static float bFloat16ToFloat(short bf) { + return Float.intBitsToFloat(bf << 16); + } + + public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { + assert bFloats.remaining() == floats.length; + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (float v : floats) { + bFloats.put(floatToBFloat16(v)); + } + } + + public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { + assert floats.length * 2 == bfBytes.length; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2)); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java new file mode 100644 index 0000000000000..859001dd218b2 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java @@ -0,0 +1,129 @@ +/* + * @notice 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.FlushInfo; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.MergeInfo; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.MergeReaderWrapper; +import org.elasticsearch.index.store.FsDirectoryFactory; + +import java.io.IOException; +import java.util.Set; + +public final class ES92BFloat16FlatVectorsFormat extends FlatVectorsFormat { + + static final String NAME = "ES92BFloat16FlatVectorsFormat"; + static final String META_CODEC_NAME = "ES92BFloat16FlatVectorsFormatMeta"; + static 
final String VECTOR_DATA_CODEC_NAME = "ES92BFloat16FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; + + public ES92BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + super(NAME); + this.vectorsScorer = vectorsScorer; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES92BFloat16FlatVectorsWriter(state, vectorsScorer); + } + + static boolean shouldUseDirectIO(SegmentReadState state) { + return ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO && FsDirectoryFactory.isHybridFs(state.directory); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + if (shouldUseDirectIO(state) && state.context.context() == IOContext.Context.DEFAULT) { + // only override the context for the random-access use case + SegmentReadState directIOState = new SegmentReadState( + state.directory, + state.segmentInfo, + state.fieldInfos, + new DirectIOContext(state.context.hints()), + state.segmentSuffix + ); + // Use mmap for merges and direct I/O for searches. + // TODO: Open the mmap file with sequential access instead of random (current behavior). 
+ return new MergeReaderWrapper( + new ES92BFloat16FlatVectorsReader(directIOState, vectorsScorer), + new ES92BFloat16FlatVectorsReader(state, vectorsScorer) + ); + } else { + return new ES92BFloat16FlatVectorsReader(state, vectorsScorer); + } + } + + @Override + public String toString() { + return "ES92BFloat16FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')'; + } + + static class DirectIOContext implements IOContext { + + final Set hints; + + DirectIOContext(Set hints) { + // always add DirectIOHint to the hints given + this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE)); + } + + @Override + public Context context() { + return Context.DEFAULT; + } + + @Override + public MergeInfo mergeInfo() { + return null; + } + + @Override + public FlushInfo flushInfo() { + return null; + } + + @Override + public Set hints() { + return hints; + } + + @Override + public IOContext withHints(FileOpenHint... hints) { + return new DirectIOContext(Set.of(hints)); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java new file mode 100644 index 0000000000000..6c4bae4beb2c3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ -0,0 +1,336 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.*; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.store.*; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +public final class ES92BFloat16FlatVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES92BFloat16FlatVectorsReader.class); + + private final IntObjectHashMap fields = new IntObjectHashMap<>(); + private final IndexInput vectorData; + private final FieldInfos fieldInfos; + + public ES92BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; + boolean success = false; + try { + 
vectorData = openDataInput( + state, + versionMeta, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Flat formats are used to randomly access vectors from their node ID that is stored + // in the HNSW graph. + state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BFloat16FlatVectorsFormat.META_EXTENSION + ); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES92BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES92BFloat16FlatVectorsFormat.VERSION_START, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + ES92BFloat16FlatVectorsFormat.VERSION_START, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format 
versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = FieldEntry.create(meta, info); + fields.put(info.number, fieldEntry); + } + } + + @Override + public long ramBytesUsed() { + return ES92BFloat16FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + final FieldEntry entry = getFieldEntryOrThrow(fieldInfo.name); + return Map.of(ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength()); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FlatVectorsReader getMergeInstance() { + try { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + this.vectorData.updateReadAdvice(ReadAdvice.SEQUENTIAL); + return this; + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + return entry; + } + + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldEntry fieldEntry = getFieldEntryOrThrow(field); + if (fieldEntry.vectorEncoding != 
expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " + expectedEncoding + ); + } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + 
fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public void finishMerge() throws IOException { + // This makes sure that the access pattern hint is reverted back since HNSW implementation + // needs it + this.vectorData.updateReadAdvice(ReadAdvice.RANDOM); + } + + @Override + public void close() throws IOException { + IOUtils.close(vectorData); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorDataOffset, + long vectorDataLength, + int dimension, + int size, + OrdToDocDISIReaderConfiguration ordToDoc, + FieldInfo info + ) { + + FieldEntry { + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + int infoVectorDimension = info.getVectorDimension(); + if (infoVectorDimension != dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + infoVectorDimension + " != " + dimension + ); + } + + int byteSize = switch (info.getVectorEncoding()) { + case BYTE -> Byte.BYTES; + case FLOAT32 -> BFloat16.BYTES; + }; + long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); + long numBytes = Math.multiplyExact(vectorBytes, size); + if (numBytes != vectorDataLength) { + throw new IllegalStateException( + "Vector data length " + + vectorDataLength + + " not matching size=" + + size + + " * dim=" + + dimension + + " * byteSize=" + + byteSize + + " = " + + numBytes + ); + } + } + + static FieldEntry create(IndexInput input, FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = 
readSimilarityFunction(input); + final var vectorDataOffset = input.readVLong(); + final var vectorDataLength = input.readVLong(); + final var dimension = input.readVInt(); + final var size = input.readInt(); + final var ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry(similarityFunction, vectorEncoding, vectorDataOffset, vectorDataLength, dimension, size, ordToDoc, info); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java new file mode 100644 index 0000000000000..ce0ad8387e24f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java @@ -0,0 +1,478 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.elasticsearch.index.codec.vectors.es92.ES92BFloat16FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +public final class ES92BFloat16FlatVectorsWriter extends 
FlatVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ES92BFloat16FlatVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final IndexOutput meta, vectorData; + + private final List> fields = new ArrayList<>(); + private boolean finished; + + public ES92BFloat16FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BFloat16FlatVectorsFormat.META_EXTENSION + ); + + String vectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION + ); + + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES92BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + vectorData, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FieldWriter newField = FieldWriter.create(fieldInfo); + fields.add(newField); + return newField; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + for (FieldWriter field : fields) { + if (sortMap == null) { + writeField(field, maxDoc); + } else { + writeSortingField(field, maxDoc, sortMap); + } + field.finish(); + } + } + + 
@Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + switch (fieldData.fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectors(fieldData); + case FLOAT32 -> writeBFloat16Vectors(fieldData); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + } + + private void writeBFloat16Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + BFloat16.floatToBFloat16((float[]) v, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + } + + private void writeByteVectors(FieldWriter fieldData) throws IOException { + for (Object v : fieldData.vectors) { + byte[] vector = (byte[]) v; + vectorData.writeBytes(vector, vector.length); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { + final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + + // write vector values + long 
vectorDataOffset = switch (fieldData.fieldInfo.getVectorEncoding()) { + case BYTE -> writeSortedByteVectors(fieldData, ordMap); + case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + } + + private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + BFloat16.floatToBFloat16(vector, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + return vectorDataOffset; + } + + private long writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + for (int ordinal : ordMap) { + byte[] vector = (byte[]) fieldData.vectors.get(ordinal); + vectorData.writeBytes(vector, vector.length); + } + return vectorDataOffset; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. 
+ long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + // No need to use temporary file as we don't have to re-open for reading + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectorData(vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); + case FLOAT32 -> writeVectorData(vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + IndexOutput tempVectorData = segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context); + IndexInput vectorDataInput = null; + boolean success = false; + try { + // write the vector data to a temporary file + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectorData(tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); + case FLOAT32 -> writeVectorData(tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + }; + CodecUtil.writeFooter(tempVectorData); + IOUtils.close(tempVectorData); + + // This temp file will be accessed in a random-access fashion to construct the HNSW graph. + // Note: don't use the context from the state, which is a flush/merge context, not expecting + // to perform random reads. 
+ vectorDataInput = segmentWriteState.directory.openInput( + tempVectorData.getName(), + IOContext.DEFAULT.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + // copy the temporary file vectors to the actual data file + vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(vectorDataInput); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + success = true; + final IndexInput finalVectorDataInput = vectorDataInput; + final RandomVectorScorerSupplier randomVectorScorerSupplier = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapByteVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * Byte.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * Float.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + }; + return new FlatCloseableRandomVectorScorerSupplier(() -> { + IOUtils.close(finalVectorDataInput); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + }, docsWithField.cardinality(), randomVectorScorerSupplier); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); + IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempVectorData.getName()); + } + } + } + + private void writeMeta(FieldInfo field, int 
maxDoc, long vectorDataOffset, long vectorDataLength, DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + + // write docIDs + int count = docsWithField.cardinality(); + meta.writeInt(count); + OrdToDocDISIReaderConfiguration.writeStoredMeta(DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + /** + * Writes the byte vector values to the output and returns a set of documents that contains + * vectors. + */ + private static DocsWithFieldSet writeByteVectorData(IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); + assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; + output.writeBytes(binaryValue, binaryValue.length); + docsWithField.add(docV); + } + return docsWithField; + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. 
+ */ + private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(iter.index()); + BFloat16.floatToBFloat16(value, buffer.asShortBuffer()); + output.writeBytes(buffer.array(), buffer.limit()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData); + } + + private abstract static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final int dim; + private final DocsWithFieldSet docsWithField; + private final List vectors; + private boolean finished; + + private int lastDocID = -1; + + static FieldWriter create(FieldInfo fieldInfo) { + int dim = fieldInfo.getVectorDimension(); + return switch (fieldInfo.getVectorEncoding()) { + case BYTE -> new ES92BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { + @Override + public byte[] copyValue(byte[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + case FLOAT32 -> new ES92BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { + @Override + public float[] copyValue(float[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + }; + } + + FieldWriter(FieldInfo fieldInfo) { + super(); + this.fieldInfo = fieldInfo; + this.dim = fieldInfo.getVectorDimension(); + this.docsWithField = new DocsWithFieldSet(); + vectors = new ArrayList<>(); + } + + @Override + public void addValue(int docID, T 
vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)" + ); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (vectors.size() == 0) return size; + + int byteSize = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 + ? BFloat16.BYTES + : fieldInfo.getVectorEncoding().byteSize; + + return size + docsWithField.ramBytesUsed() + (long) vectors.size() * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) vectors.size() * fieldInfo.getVectorDimension() * byteSize; + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + } + + static final class FlatCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + FlatCloseableRandomVectorScorerSupplier(Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors = numVectors; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + 
public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java new file mode 100644 index 0000000000000..642ec6807a93c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java @@ -0,0 +1,133 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + * Codec for encoding/decoding binary quantized vectors The binary quantization format used here + * is a per-vector optimized scalar quantization. Also see {@link + * OptimizedScalarQuantizer}. Some of key features are: + * + *
    + *
  • Estimating the distance between two vectors using their centroid normalized distance. This + * requires some additional corrective factors, but allows for centroid normalization to occur. + *
  • Optimized scalar quantization to bit level of centroid normalized vectors. + *
  • Asymmetric quantization of vectors, where query vectors are quantized to half-byte + * precision (normalized to the centroid) and then compared directly against the single bit + * quantized vectors in the index. + *
  • Transforming the half-byte quantized query vectors in such a way that the comparison with + * single bit vectors can be done with bit arithmetic. + *
+ * + * The format is stored in two files: + * + *

.veb (vector data) file

+ * + *

Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's + * corrective factors. At the end of the file, additional information is stored for vector ordinal + * to centroid ordinal mapping and sparse vector information. + * + *

    + *
  • For each vector: + *
      + *
    • [byte] the binary quantized values, each byte holds 8 bits. + *
    • [float] the optimized quantiles and an additional similarity dependent corrective factor. + *
    • short the sum of the quantized components
    • + *
    + *
  • After the vectors, sparse vector information keeping track of monotonic blocks. + *
+ * + *

.vemb (vector metadata) file

+ * + *

Stores the metadata for the vectors. This includes the number of vectors, the number of + * dimensions, and file offset information. + * + *

    + *
  • int the field number + *
  • int the vector encoding ordinal + *
  • int the vector similarity ordinal + *
  • vint the vector dimensions + *
  • vlong the offset to the vector data in the .veb file + *
  • vlong the length of the vector data in the .veb file + *
  • vint the number of vectors + *
  • [float] the centroid
  • + *
  • float the centroid square magnitude
  • + *
  • The sparse vector information, if required, mapping vector ordinal to doc ID + *
+ */ +public class ES92BinaryQuantizedBFloat16VectorsFormat extends FlatVectorsFormat { + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "ES92BinaryQuantizedBFloat16VectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "ES92BinaryQuantizedBFloat16VectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES92BinaryQuantizedBFloat16VectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = new ES92BFloat16FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. 
*/ + public ES92BinaryQuantizedBFloat16VectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES92BinaryQuantizedBFloat16VectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES92BinaryQuantizedBFloat16VectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES92BinaryQuantizedBFloat16VectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java new file mode 100644 index 0000000000000..d3f43851fe488 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java @@ -0,0 +1,391 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.SuppressForbidden; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; +import org.elasticsearch.index.codec.vectors.es818.OffHeapBinarizedVectorValues; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + 
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; +import static org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_EXTENSION; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +public class ES92BinaryQuantizedBFloat16VectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES92BinaryQuantizedBFloat16VectorsReader.class); + + private final Map fields; + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final ES818BinaryFlatVectorsScorer vectorScorer; + + @SuppressWarnings("this-escape") + ES92BinaryQuantizedBFloat16VectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + ES818BinaryFlatVectorsScorer vectorsScorer + ) throws IOException { + super(vectorsScorer); + this.fields = new HashMap<>(); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BinaryQuantizedBFloat16VectorsFormat.META_EXTENSION + ); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES92BinaryQuantizedBFloat16VectorsFormat.META_CODEC_NAME, + ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_START, + ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + 
quantizedVectorData = openDataInput( + state, + versionMeta, + VECTOR_DATA_EXTENSION, + ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. + state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private ES92BinaryQuantizedBFloat16VectorsReader(ES92BinaryQuantizedBFloat16VectorsReader clone, FlatVectorsReader rawVectorsReader) { + super(clone.vectorScorer); + this.rawVectorsReader = rawVectorsReader; + this.vectorScorer = clone.vectorScorer; + this.quantizedVectorData = clone.quantizedVectorData; + this.fields = clone.fields; + } + + @Override + public FlatVectorsReader getMergeInstance() { + return new ES92BinaryQuantizedBFloat16VectorsReader(this, rawVectorsReader.getMergeInstance()); + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension + ); + } + + int binaryDims = BQVectorUtils.discretize(dimension, 64) / 8; + long numQuantizedVectorBytes = Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size); + if (numQuantizedVectorBytes != 
fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Binarized vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (binaryBytes=" + + binaryDims + + " + 14" + + ") = " + + numQuantizedVectorBytes + ); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null || fi.size() == 0) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32 + ); + } + OffHeapBinarizedVectorValues bvv = OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ); + return new 
ES818BinaryQuantizedVectorsReader.BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + collector.collect(i, scorer.score(i)); + collector.incVisitedCount(1); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + var raw = rawVectorsReader.getOffHeapByteSize(fieldInfo); + FieldEntry fe = fields.get(fieldInfo.name); + if (fe == null) { + assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE; + return raw; + } + var quant = Map.of(VECTOR_DATA_EXTENSION, fe.vectorDataLength()); + return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant); + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return 
fieldEntry.centroid; + } + return null; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_START, + ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + int descritizedDimension, + long vectorDataOffset, + long vectorDataLength, + int size, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration + ) { + + static FieldEntry create(IndexInput input, 
VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + if (size > 0) { + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + BQVectorUtils.discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf + ); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java new file mode 100644 index 0000000000000..c071029d6e99d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java @@ -0,0 +1,726 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es818.BinarizedByteVectorValues; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; +import 
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; +import org.elasticsearch.index.codec.vectors.es818.OffHeapBinarizedVectorValues; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat.BINARIZED_VECTOR_COMPONENT; +import static org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +public class ES92BinaryQuantizedBFloat16VectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES92BinaryQuantizedBFloat16VectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, binarizedVectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final ES818BinaryFlatVectorsScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + @SuppressWarnings("this-escape") + protected ES92BinaryQuantizedBFloat16VectorsWriter( + ES818BinaryFlatVectorsScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state + ) throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; + this.segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + 
ES92BinaryQuantizedBFloat16VectorsFormat.META_EXTENSION + ); + + String binarizedVectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_EXTENSION + ); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + binarizedVectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES92BinaryQuantizedBFloat16VectorsFormat.META_CODEC_NAME, + ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + binarizedVectorData, + ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_CODEC_NAME, + ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = 
field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField(FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeBinarizedVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = fieldData.getVectors().size() > 0 ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet() + ); + } + + private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + int[] quantizationScratch = new int[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, 
scalarQuantizer); + long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, quantizedVectorLength, clusterCenter, centroidDp, newDocsWithField); + } + + private void writeSortedBinarizedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + int[] quantizationScratch = new int[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField + ) throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = 
docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, + meta, + binarizedVectorData, + count, + maxDoc, + docsWithField + ); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (binarizedVectorData != null) { + CodecUtil.writeFooter(binarizedVectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(floatVectorValues); + } + ES818BinaryQuantizedVectorsWriter.BinarizedFloatVectorValues binarizedVectorValues = + new 
ES818BinaryQuantizedVectorsWriter.BinarizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + centroid + ); + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField + ); + } else { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + } + + static DocsWithFieldSet writeBinarizedVectorAndQueryData( + IndexOutput binarizedVectorData, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + float[] centroid, + OptimizedScalarQuantizer binaryQuantizer + ) throws IOException { + int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64); + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + int[][] quantizationScratch = new int[2][floatVectorValues.dimension()]; + byte[] toIndex = new byte[discretizedDimension / 8]; + byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult[] r = binaryQuantizer.multiScalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + new byte[] { 1, 4 }, + centroid + ); + // pack and store document bit vector + BQVectorUtils.packAsBinary(quantizationScratch[0], toIndex); + binarizedVectorData.writeBytes(toIndex, toIndex.length); + 
binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection())); + assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) r[0].quantizedComponentSum()); + docsWithField.add(docV); + + // pack and store the 4bit query vector + BQSpaceUtils.transposeHalfByte(quantizationScratch[1], toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection())); + assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff; + binarizedQueryData.writeShort((short) r[1].quantizedComponentSum()); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) + throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + 
output.writeShort((short) corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC + ) throws IOException { + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + IndexInput binarizedDataInput = null; + IndexInput binarizedScoreDataInput = null; + IndexOutput tempQuantizedVectorData = null; + IndexOutput tempScoreQuantizedVectorData = null; + boolean success = false; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + // Since we are opening two files, it's possible that one or the other fails to open + // we open them within the try to ensure they are cleaned + tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( + 
binarizedVectorData.getName(), + "temp", + segmentWriteState.context + ); + tempScoreQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "score_temp", + segmentWriteState.context + ); + final String tempQuantizedVectorDataName = tempQuantizedVectorData.getName(); + final String tempScoreQuantizedVectorDataName = tempScoreQuantizedVectorData.getName(); + FloatVectorValues floatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = writeBinarizedVectorAndQueryData( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + floatVectorValues, + centroid, + quantizer + ); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + binarizedDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context); + binarizedVectorData.copyBytes(binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(binarizedDataInput); + CodecUtil.writeFooter(tempScoreQuantizedVectorData); + IOUtils.close(tempScoreQuantizedVectorData); + binarizedScoreDataInput = segmentWriteState.directory.openInput( + tempScoreQuantizedVectorData.getName(), + segmentWriteState.context + ); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField + ); + success = true; + final IndexInput finalBinarizedDataInput = binarizedDataInput; + final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; + OffHeapBinarizedVectorValues vectorValues = new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( + 
fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + quantizer, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalBinarizedDataInput + ); + RandomVectorScorerSupplier scorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues( + finalBinarizedScoreDataInput, + fieldInfo.getVectorDimension(), + docsWithField.cardinality() + ), + vectorValues + ); + return new ES818BinaryQuantizedVectorsWriter.BinarizedCloseableRandomVectorScorerSupplier(scorerSupplier, vectorValues, () -> { + IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorDataName, + tempScoreQuantizedVectorDataName + ); + }); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + binarizedDataInput, + binarizedScoreDataInput + ); + if (tempQuantizedVectorData != null) { + IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempQuantizedVectorData.getName()); + } + if (tempScoreQuantizedVectorData != null) { + IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempScoreQuantizedVectorData.getName()); + } + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids(MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws 
IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = 
iterator.nextDoc()) { + ++count; + float[] vector = vectorValues.vectorValue(iterator.index()); + // TODO Panama sum + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + } + if (count == 0) { + return count; + } + // TODO Panama div + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public 
boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java new file mode 100644 index 0000000000000..38b54d1520f9f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java @@ -0,0 +1,145 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES92HnswBinaryQuantizedBFloat16VectorsFormat extends KnnVectorsFormat { + + public 
static final String NAME = "ES92HnswBinaryQuantizedBFloat16VectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = new ES92BinaryQuantizedBFloat16VectorsFormat(); + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public ES92HnswBinaryQuantizedBFloat16VectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public ES92HnswBinaryQuantizedBFloat16VectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. 
If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public ES92HnswBinaryQuantizedBFloat16VectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES92HnswBinaryQuantizedBFloat16VectorsFormat(name=ES92HnswBinaryQuantizedBFloat16VectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java new file mode 100644 index 0000000000000..cd805a295a2f8 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java @@ -0,0 +1,311 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; + +import java.io.IOException; + +abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements HasIndexSlice { + + protected final int dimension; + protected final int size; + protected final IndexInput slice; + protected final int byteSize; + protected int lastOrd = -1; + protected final byte[] bfloatBytes; + protected final float[] value; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; + + OffHeapBFloat16VectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + this.dimension = dimension; + this.size = size; + this.slice = slice; + this.byteSize = byteSize; + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; + bfloatBytes = new byte[dimension * 2]; + value = new float[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + 
return value; + } + slice.seek((long) targetOrd * byteSize); + // no readShorts() method + slice.readBytes(bfloatBytes, 0, bfloatBytes.length); + BFloat16.bFloat16ToFloat(bfloatBytes, value); + lastOrd = targetOrd; + return value; + } + + public static OffHeapBFloat16VectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, + OrdToDocDISIReaderConfiguration configuration, + VectorEncoding vectorEncoding, + int dimension, + int size, + long vectorDataOffset, + long vectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty() || vectorEncoding != VectorEncoding.FLOAT32) { + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); + } + IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); + int byteSize = dimension * BFloat16.BYTES; + if (configuration.isDense()) { + return new DenseOffHeapVectorValues(dimension, size, bytesSlice, byteSize, flatVectorsScorer, vectorSimilarityFunction); + } else { + return new SparseOffHeapVectorValues( + configuration, + vectorData, + bytesSlice, + dimension, + size, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction + ); + } + } + + /** + * Dense vector values that are stored off-heap. This is the most common case when every doc has a + * vector. 
+ */ + public static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public int ordToDoc(int ord) { + return ord; + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.docID()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class SparseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + public SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + IndexInput dataIn, + IndexInput slice, + int dimension, + int size, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) throws IOException { + + super(dimension, size, slice, byteSize, 
flatVectorsScorer, similarityFunction); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dataIn, + slice.clone(), + dimension, + size, + byteSize, + flatVectorsScorer, + similarityFunction + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + public EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); + } + + @Override + public int dimension() { + return super.dimension(); + } + + @Override + public int size() { + return 0; + } + + @Override + public EmptyOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + 
@Override + public float[] vectorValue(int targetOrd) { + throw new UnsupportedOperationException(); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c9c14d027ebfd..5ac1a095dd1e0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -58,6 +58,8 @@ import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -386,6 +388,7 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo if (defaultBBQHnsw && dimIsConfigured && dims.getValue() >= BBQ_DIMS_DEFAULT_THRESHOLD) { return new BBQHnswIndexOptions( + 32, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, new RescoreVector(DEFAULT_OVERSAMPLE) @@ -1620,6 +1623,11 @@ public boolean supportsDimension(int dims) { BBQ_HNSW("bbq_hnsw", true) { @Override public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { + int 
rawVecSize = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("raw_vector_size"), 32); + if (rawVecSize != 32 && rawVecSize != 16) { + throw new IllegalArgumentException("Invalid raw vector size " + rawVecSize); + } + Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1638,7 +1646,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { + int rawVecSize = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("raw_vector_size"), 32); + if (rawVecSize != 32 && rawVecSize != 16) { + throw new IllegalArgumentException("Invalid raw vector size " + rawVecSize); + } + RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); @@ -1662,7 +1675,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); + case 16 -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(m, efConstruction); + default -> throw new AssertionError(); + }; } @Override @@ -2202,12 +2221,15 @@ public boolean updatableTo(DenseVectorIndexOptions update) { @Override boolean doEquals(DenseVectorIndexOptions other) { BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); + return rawVectorSize == that.rawVectorSize + && m == that.m + && efConstruction == that.efConstruction + && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(m, efConstruction, rescoreVector); + return Objects.hash(rawVectorSize, m, efConstruction, rescoreVector); } @Override @@ -2219,6 +2241,9 @@ boolean isFlat() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); 
builder.field("type", type); + if (rawVectorSize != 32) { + builder.field("raw_vector_size", rawVectorSize); + } builder.field("m", m); builder.field("ef_construction", efConstruction); if (rescoreVector != null) { @@ -2243,14 +2268,21 @@ public boolean validateDimension(int dim, boolean throwOnError) { static class BBQFlatIndexOptions extends QuantizedIndexOptions { private final int CLASS_NAME_HASH = this.getClass().getName().hashCode(); - BBQFlatIndexOptions(RescoreVector rescoreVector) { + private final int rawVectorSize; + + BBQFlatIndexOptions(int rawVectorSize, RescoreVector rescoreVector) { super(VectorIndexType.BBQ_FLAT, rescoreVector); + this.rawVectorSize = rawVectorSize; } @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES818BinaryQuantizedVectorsFormat(); + return switch (rawVectorSize) { + case 32 -> new ES818BinaryQuantizedVectorsFormat(); + case 16 -> new ES92BinaryQuantizedBFloat16VectorsFormat(); + default -> throw new AssertionError(); + }; } @Override @@ -2279,6 +2311,9 @@ boolean isFlat() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); + if (rawVectorSize != 32) { + builder.field("raw_vector_size", rawVectorSize); + } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 14e68029abc3b..8ca8b2e3c3a7f 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -7,4 +7,6 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat 
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat +org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat org.elasticsearch.index.codec.vectors.IVFVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..c74d4506d0c58 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,349 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.*; +import org.apache.lucene.misc.store.DirectIODirectory; +import org.apache.lucene.search.*; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.CheckJoinIndex; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.store.*; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es818.BinarizedByteVectorValues; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardPath; +import org.elasticsearch.index.store.FsDirectoryFactory; +import org.elasticsearch.test.IndexSettingsModule; +import org.junit.Ignore; + +import java.io.IOException; +import 
java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES92BinaryQuantizedBFloat16VectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES92BinaryQuantizedBFloat16VectorsFormat()); + + @Override + protected Codec getCodec() { + return codec; + } + + static String encodeInts(int[] i) { + return Arrays.toString(i); + } + + static BitSetProducer parentFilter(IndexReader r) throws IOException { + // Create a filter that defines "parent" documents in the index + BitSetProducer parentsFilter = new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent"))); + CheckJoinIndex.check(r, parentsFilter); + return parentsFilter; + } + + Document makeParent(int[] children) { + Document parent = new Document(); + parent.add(newStringField("docType", "_parent", Field.Store.NO)); + parent.add(newStringField("id", encodeInts(children), Field.Store.YES)); + return parent; + } + + public void testEmptyDiversifiedChildSearch() throws Exception { + String fieldName = "field"; + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + try (Directory d = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig().setCodec(codec); + iwc.setMergePolicy(new SoftDeletesRetentionMergePolicy("soft_delete", MatchAllDocsQuery::new, iwc.getMergePolicy())); + try (IndexWriter w = new IndexWriter(d, iwc)) { + List toAdd = new ArrayList<>(); + for (int j = 1; j <= 5; 
j++) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction)); + doc.add(newStringField("id", Integer.toString(j), Field.Store.YES)); + toAdd.add(doc); + } + toAdd.add(makeParent(new int[] { 1, 2, 3, 4, 5 })); + w.addDocuments(toAdd); + w.addDocuments(List.of(makeParent(new int[] { 6, 7, 8, 9, 10 }))); + w.deleteDocuments(new FieldExistsQuery(fieldName), new TermQuery(new Term("id", encodeInts(new int[] { 1, 2, 3, 4, 5 })))); + w.flush(); + w.commit(); + w.forceMerge(1); + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + BitSetProducer parentFilter = parentFilter(searcher.getIndexReader()); + Query query = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, null, 1, parentFilter); + assertTrue(searcher.search(query, 1).scoreDocs.length == 0); + } + } + + } + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, 
collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES92BinaryQuantizedBFloat16VectorsFormat(); + } + }; + String expectedPattern = "ES92BinaryQuantizedBFloat16VectorsFormat(" + + "name=ES92BinaryQuantizedBFloat16VectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testWriterRamEstimate() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testRandom() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testVectorValuesReportCorrectDocs() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testSparseVectors() throws Exception {} + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new 
IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + BinarizedByteVectorValues qvectorValues = ((ES818BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + int[] quantizedVector = new int[dims]; + byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + quantizedVector, + (byte) 1, + centroid + ); + BQVectorUtils.packAsBinary(quantizedVector, expectedVector); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + } + } + } + } + } + + public void testSimpleOffHeapSize() throws IOException { + try (Directory dir = newDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeFSDir() throws IOException { + checkDirectIOSupported(); + var config = 
newIndexWriterConfig().setUseCompoundFile(false); // avoid compound files to allow directIO + try (Directory dir = newFSDirectory()) { + testSimpleOffHeapSizeImpl(dir, config, false); + } + } + + public void testSimpleOffHeapSizeMMapDir() throws IOException { + try (Directory dir = newMMapDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (IndexWriter w = new IndexWriter(dir, config)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(expectVecOffHeap ? 
2 : 1, offHeap.size()); + assertTrue(offHeap.get("veb") > 0L); + if (expectVecOffHeap) { + assertEquals(vector.length * BFloat16.BYTES, (long) offHeap.get("vec")); + } + } + } + } + } + + static Directory newMMapDirectory() throws IOException { + Directory dir = new MMapDirectory(createTempDir("ES92BinaryQuantizedBFloat16VectorsFormatTests")); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + private Directory newFSDirectory() throws IOException { + Settings settings = Settings.builder() + .put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), IndexModule.Type.HYBRIDFS.name().toLowerCase(Locale.ROOT)) + .build(); + IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings); + Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0"); + Files.createDirectories(tempDir); + ShardPath path = new ShardPath(false, tempDir, tempDir, new ShardId(idxSettings.getIndex(), 0)); + Directory dir = (new FsDirectoryFactory()).newDirectory(idxSettings, path); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + static void checkDirectIOSupported() { + assumeTrue("Direct IO is not enabled", ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO); + + Path path = createTempDir("directIOProbe"); + try (Directory dir = open(path); IndexOutput out = dir.createOutput("out", IOContext.DEFAULT)) { + out.writeString("test"); + } catch (IOException e) { + assumeNoException("test requires a filesystem that supports Direct IO", e); + } + } + + static DirectIODirectory open(Path path) throws IOException { + return new DirectIODirectory(FSDirectory.open(path)) { + @Override + protected boolean useDirectIO(String name, IOContext context, OptionalLong fileLength) { + return true; + } + }; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..12f56d2c4bb0f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,252 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.*; +import org.apache.lucene.misc.store.DirectIODirectory; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.*; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardPath; +import org.elasticsearch.index.store.FsDirectoryFactory; +import org.elasticsearch.test.IndexSettingsModule; +import org.junit.Ignore; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Locale; +import java.util.OptionalLong; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class 
ES92HnswBinaryQuantizedBFloat16VectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES92HnswBinaryQuantizedBFloat16VectorsFormat()); + + @Override + protected Codec getCodec() { + return codec; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES92HnswBinaryQuantizedBFloat16VectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "ES92HnswBinaryQuantizedBFloat16VectorsFormat(name=ES92HnswBinaryQuantizedBFloat16VectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES92BinaryQuantizedBFloat16VectorsFormat(name=ES92BinaryQuantizedBFloat16VectorsFormat," + + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + 
assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.01f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) + ); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. 
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testWriterRamEstimate() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testRandom() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testRandomWithUpdatesAndGraph() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testVectorValuesReportCorrectDocs() throws Exception {} + + @Override + @Ignore // bfloat16 makes the results slightly out of bounds + public void testSparseVectors() throws Exception {} + + public void testSimpleOffHeapSize() throws IOException { + try (Directory dir = newDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeFSDir() throws IOException { + checkDirectIOSupported(); + var config = newIndexWriterConfig().setUseCompoundFile(false); // avoid compound files to allow directIO + try (Directory dir = newFSDirectory()) { + testSimpleOffHeapSizeImpl(dir, config, false); + } + } + + public void testSimpleOffHeapSizeMMapDir() throws IOException { + try (Directory dir = newMMapDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (IndexWriter w = new IndexWriter(dir, config)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof 
CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(expectVecOffHeap ? 3 : 2, offHeap.size()); + assertEquals(1L, (long) offHeap.get("vex")); + assertTrue(offHeap.get("veb") > 0L); + if (expectVecOffHeap) { + assertEquals(vector.length * BFloat16.BYTES, (long) offHeap.get("vec")); + } + } + } + } + } + + static Directory newMMapDirectory() throws IOException { + Directory dir = new MMapDirectory(createTempDir("ES92HnswBinaryQuantizedBFloat16VectorsFormatTests")); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + private Directory newFSDirectory() throws IOException { + Settings settings = Settings.builder() + .put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), IndexModule.Type.HYBRIDFS.name().toLowerCase(Locale.ROOT)) + .build(); + IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings); + Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0"); + Files.createDirectories(tempDir); + ShardPath path = new ShardPath(false, tempDir, tempDir, new ShardId(idxSettings.getIndex(), 0)); + Directory dir = (new FsDirectoryFactory()).newDirectory(idxSettings, path); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + static void checkDirectIOSupported() { + assumeTrue("Direct IO is not enabled", ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO); + + Path path = createTempDir("directIOProbe"); + try (Directory dir = open(path); IndexOutput out = dir.createOutput("out", IOContext.DEFAULT)) { + out.writeString("test"); + } catch (IOException e) { + assumeNoException("test requires a filesystem that supports 
Direct IO", e); + } + } + + static DirectIODirectory open(Path path) throws IOException { + return new DirectIODirectory(FSDirectory.open(path)) { + @Override + protected boolean useDirectIO(String name, IOContext context, OptionalLong fileLength) { + return true; + } + }; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 2524422ed8f90..5074fac785006 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -90,11 +90,15 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsA randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomFrom(16, 32), randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), - new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) + new DenseVectorFieldMapper.BBQFlatIndexOptions( + randomFrom(16, 32), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ) ); } @@ -118,7 +122,12 @@ private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsHnswQua randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), rescoreVector ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), rescoreVector) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomFrom(16, 32), + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + rescoreVector + ) ); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java 
b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 8e62e18cf02c1..06b05f6fa0f11 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -37,6 +37,8 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -215,6 +217,8 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th IndexWriterConfig iwc = new IndexWriterConfig(); // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore KnnVectorsFormat format = randomFrom( + new ES92BinaryQuantizedBFloat16VectorsFormat(), + new ES92HnswBinaryQuantizedBFloat16VectorsFormat(), new ES818BinaryQuantizedVectorsFormat(), new ES818HnswBinaryQuantizedVectorsFormat(), new ES813Int8FlatVectorFormat(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 0d260a557f602..aadbf94cfad11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1295,7 +1295,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDense int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(32, m, efConstruction, rescoreVector); } static SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 7ee178cbe2af6..74f18e9c45ab1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -1317,7 +1317,7 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDens int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(32, m, efConstruction, rescoreVector); } private static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() { From b008aea2092e4751bfba9bcddc5f98ae509ccff3 Mon Sep 17 00:00:00 2001 From: Simon Cooper 
Date: Thu, 7 Aug 2025 12:16:25 +0100 Subject: [PATCH 02/17] Add to KnnIndexTester --- .../elasticsearch/test/knn/CmdLineArgs.java | 11 ++++++ .../test/knn/KnnIndexTester.java | 14 ++++++- server/src/main/java/module-info.java | 1 + .../es92/ES92BFloat16FlatVectorsReader.java | 20 ++++++++-- .../es92/ES92BFloat16FlatVectorsWriter.java | 8 +++- .../es92/OffHeapBFloat16VectorValues.java | 10 ++--- ...ryQuantizedBFloat16VectorsFormatTests.java | 39 ++++++++++++++----- ...ryQuantizedBFloat16VectorsFormatTests.java | 23 +++++++---- 8 files changed, 97 insertions(+), 29 deletions(-) diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 85fa02aecaaef..887a71734d447 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -49,6 +49,7 @@ record CmdLineArgs( float filterSelectivity, long seed, VectorSimilarityFunction vectorSpace, + int rawVectorSize, int quantizeBits, VectorEncoding vectorEncoding, int dimensions, @@ -75,6 +76,7 @@ record CmdLineArgs( static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); + static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size"); static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); @@ -108,6 +110,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); PARSER.declareString(Builder::setVectorSpace, 
VECTOR_SPACE_FIELD); + PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD); PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); @@ -143,6 +146,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(REINDEX_FIELD.getPreferredName(), reindex); builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); + builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize); builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); @@ -176,6 +180,7 @@ static class Builder { private boolean reindex = false; private boolean forceMerge = false; private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; + private int rawVectorSize = 32; private int quantizeBits = 8; private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; private int dimensions; @@ -278,6 +283,11 @@ public Builder setVectorSpace(String vectorSpace) { return this; } + public Builder setRawVectorSize(int rawVectorSize) { + this.rawVectorSize = rawVectorSize; + return this; + } + public Builder setQuantizeBits(int quantizeBits) { this.quantizeBits = quantizeBits; return this; @@ -343,6 +353,7 @@ public CmdLineArgs build() { filterSelectivity, seed, vectorSpace, + rawVectorSize, quantizeBits, vectorEncoding, dimensions, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index b0695137dfef4..b36314ed3a5a0 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ 
b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -30,6 +30,8 @@ import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat; import org.elasticsearch.logging.Level; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -105,9 +107,17 @@ static Codec createCodec(CmdLineArgs args) { } else { if (args.quantizeBits() == 1) { if (args.indexType() == IndexType.FLAT) { - format = new ES818BinaryQuantizedVectorsFormat(); + if (args.rawVectorSize() == 16) { + format = new ES92BinaryQuantizedBFloat16VectorsFormat(); + } else { + format = new ES818BinaryQuantizedVectorsFormat(); + } } else { - format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null); + if (args.rawVectorSize() == 16) { + format = new ES92HnswBinaryQuantizedBFloat16VectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null); + } else { + format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null); + } } } else if (args.quantizeBits() < 32) { if (args.indexType() == IndexType.FLAT) { diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index d0ac9a17b0779..21d255aece93f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -487,5 +487,6 @@ exports org.elasticsearch.index.codec.perfield; exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn; + exports org.elasticsearch.index.codec.vectors.es92 to 
org.elasticsearch.test.knn; exports org.elasticsearch.inference.telemetry; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java index 6c4bae4beb2c3..dac2113ccd9ef 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ -24,12 +24,26 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.index.*; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; -import org.apache.lucene.store.*; -import org.apache.lucene.util.IOUtils; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.core.IOUtils; import java.io.IOException; import java.io.UncheckedIOException; diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java index ce0ad8387e24f..a3ef930574aa5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java @@ -43,11 +43,11 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.core.IOUtils; import java.io.Closeable; import java.io.IOException; @@ -292,7 +292,11 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldI } finally { if (success == false) { IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); - IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempVectorData.getName()); + try { + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + } catch (Exception e) { + // ignore + } } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java index cd805a295a2f8..f793c94d47a54 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java @@ -93,7 +93,7 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } - public static 
OffHeapBFloat16VectorValues load( + static OffHeapBFloat16VectorValues load( VectorSimilarityFunction vectorSimilarityFunction, FlatVectorsScorer flatVectorsScorer, OrdToDocDISIReaderConfiguration configuration, @@ -129,9 +129,9 @@ public static OffHeapBFloat16VectorValues load( * Dense vector values that are stored off-heap. This is the most common case when every doc has a * vector. */ - public static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues { - public DenseOffHeapVectorValues( + DenseOffHeapVectorValues( int dimension, int size, IndexInput slice, @@ -188,7 +188,7 @@ private static class SparseOffHeapVectorValues extends OffHeapBFloat16VectorValu private final IndexInput dataIn; private final OrdToDocDISIReaderConfiguration configuration; - public SparseOffHeapVectorValues( + SparseOffHeapVectorValues( OrdToDocDISIReaderConfiguration configuration, IndexInput dataIn, IndexInput slice, @@ -269,7 +269,7 @@ public DocIdSetIterator iterator() { private static class EmptyOffHeapVectorValues extends OffHeapBFloat16VectorValues { - public EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { + EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java index c74d4506d0c58..4bb9dcae1c116 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java @@ -27,14 +27,35 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.*; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.SoftDeletesRetentionMergePolicy; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.misc.store.DirectIODirectory; -import org.apache.lucene.search.*; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.CheckJoinIndex; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.QueryBitSetProducer; -import org.apache.lucene.store.*; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.apache.lucene.tests.util.TestUtil; @@ -52,12 
+73,15 @@ import org.elasticsearch.index.shard.ShardPath; import org.elasticsearch.index.store.FsDirectoryFactory; import org.elasticsearch.test.IndexSettingsModule; -import org.junit.Ignore; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.OptionalLong; import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; @@ -187,20 +211,17 @@ public void testSearchWithVisitedLimit() { // visited limit is not respected, as it is brute force search } + // bfloat16 makes the results of these tests slightly out of bounds @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testWriterRamEstimate() throws Exception {} @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testRandom() throws Exception {} @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testVectorValuesReportCorrectDocs() throws Exception {} @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testSparseVectors() throws Exception {} public void testQuantizedVectorsWriteAndRead() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java index 12f56d2c4bb0f..73d732e5fcc3e 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -27,10 +27,22 @@ import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import 
org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.*; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.misc.store.DirectIODirectory; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.store.*; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.apache.lucene.tests.util.TestUtil; @@ -45,7 +57,6 @@ import org.elasticsearch.index.shard.ShardPath; import org.elasticsearch.index.store.FsDirectoryFactory; import org.elasticsearch.test.IndexSettingsModule; -import org.junit.Ignore; import java.io.IOException; import java.nio.file.Files; @@ -140,24 +151,20 @@ public void testVectorSimilarityFuncs() { assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); } + // bfloat16 makes the results of these tests slightly out of bounds @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testWriterRamEstimate() throws Exception {} @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testRandom() throws Exception {} @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testRandomWithUpdatesAndGraph() throws Exception {} @Override - @Ignore // bfloat16 makes the results 
slightly out of bounds public void testVectorValuesReportCorrectDocs() throws Exception {} @Override - @Ignore // bfloat16 makes the results slightly out of bounds public void testSparseVectors() throws Exception {} public void testSimpleOffHeapSize() throws IOException { From 370feb23c8c7496e8c96f9f528d294ca917133c7 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 7 Aug 2025 14:20:43 +0100 Subject: [PATCH 03/17] Add yaml tests --- .../41_knn_search_bbq_hnsw_bfloat16.yml | 610 ++++++++++++++++++ .../42_knn_search_bbq_flat_bfloat16.yml | 525 +++++++++++++++ 2 files changed, 1135 insertions(+) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml new file mode 100644 index 0000000000000..501a9adb4227e --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -0,0 +1,610 @@ +setup: + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + raw_vector_size: 16 + + - do: + index: + index: bbq_hnsw + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 
0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 + + - do: + indices.refresh: { } +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 
0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + 
Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad quantization parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + element_type: byte + index: true + index_options: + type: bbq_hnsw + raw_vector_size: 16 + + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: false + index_options: + type: bbq_hnsw + raw_vector_size: 16 +--- +"Test bad raw vector size": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + element_type: byte + index: true + index_options: + type: bbq_hnsw + raw_vector_size: 25 +--- +"Test few dimensions fail indexing": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + 
type: dense_vector + dims: 42 + index: true + index_options: + type: bbq_hnsw + raw_vector_size: 16 + + - do: + indices.create: + index: dynamic_dim_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + index: true + similarity: l2_norm + index_options: + type: bbq_hnsw + raw_vector_size: 16 + + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + - do: + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + raw_vector_size: 16 + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, 
-0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + 
script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector updateable and settable to 0": + - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + + - do: + indices.create: + index: bbq_rescore_0_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + index_options: + type: bbq_hnsw + raw_vector_size: 16 + rescore_vector: + oversample: 0 + + - do: + indices.create: + index: bbq_rescore_update_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + index_options: + type: bbq_hnsw + raw_vector_size: 16 + rescore_vector: + oversample: 1 + + - do: + indices.put_mapping: + index: bbq_rescore_update_hnsw + body: + properties: + vector: + type: dense_vector + index_options: + type: bbq_hnsw + raw_vector_size: 16 + rescore_vector: + oversample: 0 + + - do: + indices.get_mapping: + index: bbq_rescore_update_hnsw + + - match: { 
.bbq_rescore_update_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 0 }
-0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: raw_score0 } + - set: { hits.hits.1._score: raw_score1 } + - set: { hits.hits.2._score: raw_score2 } + + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 2 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: override_score0 } + - set: { hits.hits.1._score: override_score1 } + - set: { 
hits.hits.2._score: override_score2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + raw_vector_size: 16 + rescore_vector: + oversample: 2 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: default_rescore0 } + - set: { hits.hits.1._score: default_rescore1 } + - set: { hits.hits.2._score: default_rescore2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + raw_vector_size: 16 + rescore_vector: + oversample: 0 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $override_score0 } + - match: { hits.hits.0._score: $default_rescore0 } + - match: { hits.hits.1._score: $override_score1 } + - match: { hits.hits.1._score: $default_rescore1 } + - match: { hits.hits.2._score: $override_score2 } + - match: { hits.hits.2._score: $default_rescore2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_hnsw + + - match: { bbq_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml new file mode 100644 index 0000000000000..f0a7ec15f6952 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -0,0 +1,525 @@ +setup: + - do: + indices.create: + index: bbq_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + 
index_options: + type: bbq_flat + raw_vector_size: 16 + + - do: + index: + index: bbq_flat + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "3" + body: + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + 
indices.forcemerge: + index: bbq_flat + max_num_segments: 1 +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , 
-0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + query: + script_score: + query: { match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + index_options: + type: bbq_flat + raw_vector_size: 16 + m: 42 + + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + element_type: byte + index: true + index_options: + type: bbq_flat + raw_vector_size: 16 +--- +"Test bad 
raw vector size": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + index_options: + type: bbq_flat + raw_vector_size: 25 +--- +"Test few dimensions fail indexing": + # verify index creation fails + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + dims: 42 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + raw_vector_size: 16 + + # create an index with dynamic dims (creation succeeds; odd-dim docs are rejected below) + - do: + indices.create: + index: dynamic_dim_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + index: true + similarity: l2_norm + index_options: + type: bbq_flat + raw_vector_size: 16 + + # verify index fails for odd dim vector + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + # verify that we can index an even dim vector after the odd dim vector failure + - do: + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + raw_vector_size: 16 + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + {
"index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 
0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_flat + + - match: { bbq_flat.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } +--- +"Test nested queries": + - do: + indices.create: + index: bbq_flat_nested + body: + settings: + index: + 
number_of_shards: 1 + mappings: + properties: + name: + type: keyword + nested: + type: nested + properties: + paragraph_id: + type: keyword + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + raw_vector_size: 16 + + - do: + index: + index: bbq_flat_nested + id: "1" + body: + nested: + - paragraph_id: "1" + vector: [ 0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45 ] + - paragraph_id: "2" + vector: [ 0.7, 0.2 , 0.205, 0.63 , 0.032, 0.201, 0.167, 0.313, + 0.176, 0.1, 0.375, 0.334, 0.046, 0.078, 0.349, 0.272, + 0.307, 0.083, 0.504, 0.255, 0.404, 0.289, 0.226, 0.132, + 0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , 0.265, + 0.285, 0.336, 0.272, 0.369, -0.282, 0.086, 0.132, 0.475, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.028, 0.321, 0.022, 0.009, 0.001 , 0.031, -0.533, 0.45] + - do: + index: + index: bbq_flat_nested + id: "2" + body: + nested: + - paragraph_id: 0 + vector: [ 0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, -0.013 ] + - paragraph_id: 2 + vector: 
[ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, 0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, 0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, 0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, 0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, 0.013 ] + - paragraph_id: 3 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + 0.08 , 0.442, -0.067, -0.05 , 0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, 0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, 0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, 0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, 0.209, -0.153, -0.27, -0.013 ] + + - do: + index: + index: bbq_flat_nested + id: "3" + body: + nested: + - paragraph_id: 0 + vector: [ 0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058 ] + + - do: + indices.flush: + index: bbq_flat_nested + + - do: + indices.forcemerge: + index: bbq_flat_nested + max_num_segments: 1 + + - do: + search: + index: bbq_flat_nested + body: + query: + nested: + path: nested + query: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 
0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} + + - do: + search: + index: bbq_flat_nested + body: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} From 522c8aef9d83c54b8dd8dfda3085d3a7572d4c08 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 7 Aug 2025 15:37:29 +0100 Subject: [PATCH 04/17] Add test feature for bfloat16 --- .../test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml | 3 +++ .../test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml | 3 +++ .../java/org/elasticsearch/index/mapper/MapperFeatures.java | 4 +++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index 501a9adb4227e..9c387da6e0807 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -1,4 +1,7 @@ setup: + - requires: + cluster_features: "mapper.bbq_bfloat16" + reason: 'bfloat16 needs to be supported' - do: indices.create: index: bbq_hnsw diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index f0a7ec15f6952..46c85cc89ac5d 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -1,4 +1,7 @@ setup: + - requires: + cluster_features: "mapper.bbq_bfloat16" + reason: 'bfloat16 needs to be supported' - do: indices.create: index: bbq_flat diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 7ba2dfb9a69f5..2480a7a373bef 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -47,6 +47,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature BBQ_DISK_SUPPORT = new NodeFeature("mapper.bbq_disk_support"); static final NodeFeature SEARCH_LOAD_PER_SHARD = new NodeFeature("mapper.search_load_per_shard"); static final NodeFeature PATTERNED_TEXT = new NodeFeature("mapper.patterned_text"); + static final NodeFeature BBQ_BFLOAT16 = new NodeFeature("mapper.bbq_bfloat16"); @Override public Set getTestFeatures() { @@ -80,7 +81,8 @@ public Set getTestFeatures() { BBQ_DISK_SUPPORT, SEARCH_LOAD_PER_SHARD, SPARSE_VECTOR_INDEX_OPTIONS_FEATURE, - PATTERNED_TEXT + PATTERNED_TEXT, + BBQ_BFLOAT16 ); } } From 
f979e891f4659a98f5ad36136ceda28c7d272dba Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 7 Aug 2025 16:51:10 +0100 Subject: [PATCH 05/17] Define a new element type rather than a special option --- .../41_knn_search_bbq_hnsw_bfloat16.yml | 55 +--- .../42_knn_search_bbq_flat_bfloat16.yml | 28 +- .../vectors/DenseVectorFieldMapper.java | 247 ++++++++++++++---- .../mapper/vectors/VectorDVLeafFieldData.java | 6 +- .../script/VectorScoreScriptUtils.java | 8 +- .../BFloat16RankVectorsDocValuesField.java | 158 +++++++++++ .../DenseVectorFieldMapperTestUtils.java | 6 +- .../vectors/DenseVectorFieldMapperTests.java | 4 +- .../vectors/DenseVectorFieldTypeTests.java | 13 +- ...AbstractKnnVectorQueryBuilderTestCase.java | 4 +- .../TestDenseInferenceServiceExtension.java | 2 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../elastic/ElasticTextEmbeddingPayload.java | 2 +- ...cInferenceMetadataFieldsRecoveryTests.java | 2 +- .../mapper/SemanticTextFieldMapperTests.java | 2 +- .../mapper/SemanticTextFieldTests.java | 8 +- .../queries/SemanticQueryBuilderTests.java | 2 +- .../mapper/RankVectorsDVLeafFieldData.java | 44 +++- .../script/RankVectorsScoreScriptUtils.java | 2 +- .../mapper/RankVectorsFieldMapperTests.java | 4 +- .../RankVectorsScriptDocValuesTests.java | 95 ++++++- 21 files changed, 535 insertions(+), 159 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index 9c387da6e0807..700d3cc33b51e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -13,12 +13,12 @@ 
setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_hnsw - raw_vector_size: 16 - do: index: @@ -191,44 +191,11 @@ setup: properties: vector: type: dense_vector - dims: 64 - element_type: byte - index: true - index_options: - type: bbq_hnsw - raw_vector_size: 16 - - - do: - catch: bad_request - indices.create: - index: bad_bbq_hnsw - body: - mappings: - properties: - vector: - type: dense_vector + element_type: bfloat16 dims: 64 index: false index_options: type: bbq_hnsw - raw_vector_size: 16 ---- -"Test bad raw vector size": - - do: - catch: bad_request - indices.create: - index: bad_bbq_hnsw - body: - mappings: - properties: - vector: - type: dense_vector - dims: 64 - element_type: byte - index: true - index_options: - type: bbq_hnsw - raw_vector_size: 25 --- "Test few dimensions fail indexing": - do: @@ -240,11 +207,11 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 42 index: true index_options: type: bbq_hnsw - raw_vector_size: 16 - do: indices.create: @@ -254,11 +221,11 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 index: true similarity: l2_norm index_options: type: bbq_hnsw - raw_vector_size: 16 - do: catch: bad_request @@ -290,12 +257,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_hnsw - raw_vector_size: 16 rescore_vector: oversample: 1.5 @@ -380,9 +347,9 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 index_options: type: bbq_hnsw - raw_vector_size: 16 rescore_vector: oversample: 0 @@ -397,9 +364,9 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 index_options: type: bbq_hnsw - raw_vector_size: 16 rescore_vector: oversample: 1 @@ -410,9 +377,9 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 index_options: type: bbq_hnsw - 
raw_vector_size: 16 rescore_vector: oversample: 0 @@ -439,12 +406,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_hnsw - raw_vector_size: 16 rescore_vector: oversample: 0 @@ -520,12 +487,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_hnsw - raw_vector_size: 16 rescore_vector: oversample: 2 @@ -561,12 +528,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_hnsw - raw_vector_size: 16 rescore_vector: oversample: 0 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index 46c85cc89ac5d..82f059b05eb09 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -13,12 +13,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_flat - raw_vector_size: 16 - do: index: @@ -187,28 +187,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true index_options: type: bbq_flat - raw_vector_size: 16 m: 42 - - - do: - catch: bad_request - indices.create: - index: bad_bbq_flat - body: - mappings: - properties: - vector: - type: dense_vector - dims: 64 - element_type: byte - index: true - index_options: - type: bbq_flat - raw_vector_size: 16 --- "Test bad raw vector size": - do: @@ -237,12 +221,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 42 
index: true similarity: l2_norm index_options: type: bbq_flat - raw_vector_size: 16 # verify dynamic dimension fails - do: @@ -253,11 +237,11 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 index: true similarity: l2_norm index_options: type: bbq_flat - raw_vector_size: 16 # verify index fails for odd dim vector - do: @@ -291,12 +275,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_flat - raw_vector_size: 16 rescore_vector: oversample: 1.5 @@ -395,12 +379,12 @@ setup: type: keyword vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_flat - raw_vector_size: 16 - do: index: diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 5ac1a095dd1e0..b63a023140e58 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -388,7 +388,6 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo if (defaultBBQHnsw && dimIsConfigured && dims.getValue() >= BBQ_DIMS_DEFAULT_THRESHOLD) { return new BBQHnswIndexOptions( - 32, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, new RescoreVector(DEFAULT_OVERSAMPLE) @@ -1123,6 +1122,182 @@ public void checkDimensions(Integer dvDims, int qvDims) { ); } } + }, + + BFLOAT16 { + + @Override + public String toString() { + return "bfloat16"; + } + + @Override + public void writeValue(ByteBuffer byteBuffer, float value) { + byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); + } + + @Override + public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { 
+ b.value(Float.intBitsToFloat(byteBuffer.getShort() << 16)); + } + + private KnnFloatVectorField createKnnVectorField(String name, float[] vector, VectorSimilarityFunction function) { + if (vector == null) { + throw new IllegalArgumentException("vector value must not be null"); + } + FieldType denseVectorFieldType = new FieldType(); + denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.FLOAT32, function); + denseVectorFieldType.freeze(); + return new KnnFloatVectorField(name, vector, denseVectorFieldType); + } + + @Override + IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) { + return new VectorIndexFieldData.Builder( + denseVectorFieldType.name(), + CoreValuesSourceType.KEYWORD, + denseVectorFieldType.indexVersionCreated, + this, + denseVectorFieldType.dims, + denseVectorFieldType.indexed, + denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) + && denseVectorFieldType.indexed + && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? 
r -> new FilterLeafReader(r) { + @Override + public CacheHelper getCoreCacheHelper() { + return r.getCoreCacheHelper(); + } + + @Override + public CacheHelper getReaderCacheHelper() { + return r.getReaderCacheHelper(); + } + + @Override + public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { + FloatVectorValues values = in.getFloatVectorValues(fieldName); + if (values == null) { + return null; + } + return new DenormalizedCosineFloatVectorValues( + values, + in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + } : r -> r + ); + } + + @Override + StringBuilder checkVectorErrors(float[] vector) { + return checkNanAndInfinite(vector); + } + + @Override + void checkVectorMagnitude( + VectorSimilarity similarity, + Function appender, + float squaredMagnitude + ) { + StringBuilder errorBuilder = null; + + if (Float.isNaN(squaredMagnitude) || Float.isInfinite(squaredMagnitude)) { + errorBuilder = new StringBuilder( + "NaN or Infinite magnitude detected, this usually means the vector values are too extreme to fit within a float." + ); + } + if (errorBuilder != null) { + throw new IllegalArgumentException(appender.apply(errorBuilder).toString()); + } + + if (similarity == VectorSimilarity.DOT_PRODUCT && isNotUnitVector(squaredMagnitude)) { + errorBuilder = new StringBuilder( + "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors." + ); + } else if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) { + errorBuilder = new StringBuilder( + "The [" + VectorSimilarity.COSINE + "] similarity does not support vectors with zero magnitude." 
+ ); + } + + if (errorBuilder != null) { + throw new IllegalArgumentException(appender.apply(errorBuilder).toString()); + } + } + + @Override + public double computeSquaredMagnitude(VectorData vectorData) { + return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector()); + } + + @Override + public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + int index = 0; + float[] vector = new float[fieldMapper.fieldType().dims]; + float squaredMagnitude = 0; + for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + fieldMapper.checkDimensionExceeded(index, context); + ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); + + float value = context.parser().floatValue(true); + vector[index++] = value; + squaredMagnitude += value * value; + } + fieldMapper.checkDimensionMatches(index, context); + checkVectorBounds(vector); + checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); + if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) + && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) + && isNotUnitVector(squaredMagnitude)) { + float length = (float) Math.sqrt(squaredMagnitude); + for (int i = 0; i < vector.length; i++) { + vector[i] /= length; + } + final String fieldName = fieldMapper.fieldType().name() + COSINE_MAGNITUDE_FIELD_SUFFIX; + Field magnitudeField = new FloatDocValuesField(fieldName, length); + context.doc().addWithKey(fieldName, magnitudeField); + } + Field field = createKnnVectorField( + fieldMapper.fieldType().name(), + vector, + fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this) + ); + context.doc().addWithKey(fieldMapper.fieldType().name(), field); + } + + @Override + public VectorData parseKnnVector( + DocumentParserContext context, + int dims, + IntBooleanConsumer 
dimChecker, + VectorSimilarity similarity + ) throws IOException { + int index = 0; + float squaredMagnitude = 0; + float[] vector = new float[dims]; + for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + dimChecker.accept(index, false); + ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); + float value = context.parser().floatValue(true); + vector[index] = value; + squaredMagnitude += value * value; + index++; + } + dimChecker.accept(index, true); + checkVectorBounds(vector); + checkVectorMagnitude(similarity, errorFloatElementsAppender(vector), squaredMagnitude); + return VectorData.fromFloats(vector); + } + + @Override + public int getNumBytes(int dimensions) { + return dimensions * Short.BYTES; + } + + @Override + public ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) { + return ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN); + } }; public abstract void writeValue(ByteBuffer byteBuffer, float value); @@ -1288,7 +1463,9 @@ public static ElementType fromString(String name) { ElementType.FLOAT.toString(), ElementType.FLOAT, ElementType.BIT.toString(), - ElementType.BIT + ElementType.BIT, + ElementType.BFLOAT16.toString(), + ElementType.BFLOAT16 ); public enum VectorSimilarity { @@ -1296,7 +1473,7 @@ public enum VectorSimilarity { @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> 1f / (1f + similarity * similarity); + case BYTE, FLOAT, BFLOAT16 -> 1f / (1f + similarity * similarity); case BIT -> (dim - similarity) / dim; }; } @@ -1311,7 +1488,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { assert elementType != ElementType.BIT; return switch (elementType) { - case BYTE, FLOAT -> (1 + similarity) / 2f; + case BYTE, FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> 
throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1328,7 +1505,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15)); - case FLOAT -> (1 + similarity) / 2f; + case FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1342,7 +1519,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; + case BYTE, FLOAT, BFLOAT16 -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1623,11 +1800,6 @@ public boolean supportsDimension(int dims) { BBQ_HNSW("bbq_hnsw", true) { @Override public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { - int rawVecSize = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("raw_vector_size"), 32); - if (rawVecSize != 32 && rawVecSize != 16) { - throw new IllegalArgumentException("Invalid raw vector size " + rawVecSize); - } - Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1646,12 +1818,12 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { - int rawVecSize = XContentMapValues.nodeIntegerValue(indexOptionsMap.remove("raw_vector_size"), 32); - if (rawVecSize != 32 && rawVecSize != 16) { - throw new IllegalArgumentException("Invalid raw vector size " + rawVecSize); - } - 
RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); @@ -1675,12 +1842,12 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); - case 16 -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(m, efConstruction); + return switch (elementType) { + case FLOAT -> new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); + case BFLOAT16 -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(m, efConstruction); default -> throw new AssertionError(); }; } @@ -2221,15 +2385,12 @@ public boolean updatableTo(DenseVectorIndexOptions update) { @Override boolean doEquals(DenseVectorIndexOptions other) { BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return rawVectorSize == that.rawVectorSize - && m == that.m - && efConstruction == that.efConstruction - && Objects.equals(rescoreVector, that.rescoreVector); + return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(rawVectorSize, m, efConstruction, rescoreVector); + return Objects.hash(m, efConstruction, rescoreVector); } @Override @@ -2241,9 +2402,6 @@ boolean isFlat() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); - if (rawVectorSize != 32) { - builder.field("raw_vector_size", rawVectorSize); - } builder.field("m", m); builder.field("ef_construction", efConstruction); if (rescoreVector != null) { @@ -2268,19 +2426,15 @@ public boolean validateDimension(int dim, boolean throwOnError) { static class BBQFlatIndexOptions extends QuantizedIndexOptions { private final int CLASS_NAME_HASH = this.getClass().getName().hashCode(); - private final int rawVectorSize; - - BBQFlatIndexOptions(int rawVectorSize, RescoreVector rescoreVector) 
{ + BBQFlatIndexOptions(RescoreVector rescoreVector) { super(VectorIndexType.BBQ_FLAT, rescoreVector); - this.rawVectorSize = rawVectorSize; } @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - assert elementType == ElementType.FLOAT; - return switch (rawVectorSize) { - case 32 -> new ES818BinaryQuantizedVectorsFormat(); - case 16 -> new ES92BinaryQuantizedBFloat16VectorsFormat(); + return switch (elementType) { + case FLOAT -> new ES818BinaryQuantizedVectorsFormat(); + case BFLOAT16 -> new ES92BinaryQuantizedBFloat16VectorsFormat(); default -> throw new AssertionError(); }; } @@ -2311,9 +2465,6 @@ boolean isFlat() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); - if (rawVectorSize != 32) { - builder.field("raw_vector_size", rawVectorSize); - } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } @@ -2517,7 +2668,7 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) } Query knnQuery = switch (elementType) { case BYTE -> createExactKnnByteQuery(queryVector.asByteVector()); - case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector()); + case FLOAT, BFLOAT16 -> createExactKnnFloatQuery(queryVector.asFloatVector()); case BIT -> createExactKnnBitQuery(queryVector.asByteVector()); }; if (vectorSimilarity != null) { @@ -2590,7 +2741,7 @@ public Query createKnnQuery( knnSearchStrategy, hnswEarlyTermination ); - case FLOAT -> createKnnFloatQuery( + case FLOAT, BFLOAT16 -> createKnnFloatQuery( queryVector.asFloatVector(), k, numCands, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index e44202d353629..fc99438c6d076 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -69,14 +69,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { if (indexed) { return switch (elementType) { case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); - case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); + case FLOAT, BFLOAT16 -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); }; } else { BinaryDocValues values = DocValues.getBinary(reader, field); return switch (elementType) { case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims); - case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case FLOAT, BFLOAT16 -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims); }; } @@ -138,7 +138,7 @@ public Object nextValue() { return vectorValue; } }; - case FLOAT -> new FormattedDocValues() { + case FLOAT, BFLOAT16 -> new FormattedDocValues() { float[] vector = new float[dims]; private FloatVectorValues floatVectorValues; // use when indexed private KnnVectorValues.DocIndexIterator iterator; // use when indexed diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index d319c8d8b40aa..878cb42bdcce0 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -209,7 +209,7 @@ public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: 
" + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL1Norm(scoreScript, field, (List) queryVector); } @@ -319,7 +319,7 @@ public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL2Norm(scoreScript, field, (List) queryVector); } @@ -477,7 +477,7 @@ public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatDotProduct(scoreScript, field, (List) queryVector); } @@ -546,7 +546,7 @@ public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fiel } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatCosineSimilarity(scoreScript, field, (List) queryVector); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java new file mode 100644 index 0000000000000..21326faf1b6ab --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; +import java.util.Iterator; + +public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField { + + private final BinaryDocValues input; + private final BinaryDocValues magnitudes; + private boolean decoded; + private final int dims; + private BytesRef value; + private BytesRef magnitudesValue; + private BFloat16VectorIterator vectorValues; + private int numVectors; + private float[] buffer; + + public BFloat16RankVectorsDocValuesField( + BinaryDocValues input, + BinaryDocValues magnitudes, + String name, + ElementType elementType, + int dims + ) { + super(name, elementType); + this.input = input; + this.magnitudes = magnitudes; + this.dims = dims; + this.buffer = new float[dims]; + } + + @Override + public void setNextDocId(int docId) throws IOException { + decoded = false; + if (input.advanceExact(docId)) { + boolean magnitudesFound = magnitudes.advanceExact(docId); + assert magnitudesFound; + + value = input.binaryValue(); + assert value.length % (Short.BYTES * dims) == 0; + numVectors = value.length / (Short.BYTES * dims); + magnitudesValue = magnitudes.binaryValue(); + assert magnitudesValue.length == (Float.BYTES * numVectors); + } else { + value = null; + magnitudesValue = null; + 
numVectors = 0; + } + } + + @Override + public RankVectorsScriptDocValues toScriptDocValues() { + return new RankVectorsScriptDocValues(this, dims); + } + + @Override + public boolean isEmpty() { + return value == null; + } + + @Override + public RankVectors get() { + if (isEmpty()) { + return RankVectors.EMPTY; + } + decodeVectorIfNecessary(); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); + } + + @Override + public RankVectors get(RankVectors defaultValue) { + if (isEmpty()) { + return defaultValue; + } + decodeVectorIfNecessary(); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); + } + + @Override + public RankVectors getInternal() { + return get(null); + } + + @Override + public int size() { + return value == null ? 0 : value.length / (Short.BYTES * dims); + } + + private void decodeVectorIfNecessary() { + if (decoded == false && value != null) { + vectorValues = new BFloat16VectorIterator(value, buffer, numVectors); + decoded = true; + } + } + + public static class BFloat16VectorIterator implements VectorIterator { + private final float[] buffer; + private final ShortBuffer vectorValues; + private final BytesRef vectorValueBytesRef; + private final int size; + private int idx = 0; + + public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) { + assert vectorValues.length == (buffer.length * Short.BYTES * size); + this.vectorValueBytesRef = vectorValues; + this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + this.size = size; + this.buffer = buffer; + } + + @Override + public boolean hasNext() { + return idx < size; + } + + @Override + public float[] next() { + if (hasNext() == false) { + throw new IllegalArgumentException("No more elements in the iterator"); + } + for (int i = 0; i < buffer.length; i++) { + buffer[i] = Float.intBitsToFloat(vectorValues.get() << 16); + } + 
idx++; + return buffer; + } + + @Override + public Iterator copy() { + return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size); + } + + @Override + public void reset() { + idx = 0; + vectorValues.rewind(); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java index 9478508da88d0..c3e3ed0c93c98 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java @@ -22,14 +22,14 @@ private DenseVectorFieldMapperTestUtils() {} public static List getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) { return switch (elementType) { - case FLOAT, BYTE -> List.of(SimilarityMeasure.values()); + case FLOAT, BFLOAT16, BYTE -> List.of(SimilarityMeasure.values()); case BIT -> List.of(SimilarityMeasure.L2_NORM); }; } public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BFLOAT16, BYTE -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; @@ -43,7 +43,7 @@ public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType } return switch (elementType) { - case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); + case FLOAT, BFLOAT16, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); case BIT -> { if (max < 8) { throw new IllegalArgumentException("max must be at least 8 for bit vectors"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index a5c5a9a9b42ef..dbea92b0b5e83 100644 
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -2462,7 +2462,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) ft; return switch (vectorFieldType.getElementType()) { case BYTE -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions()); - case FLOAT -> { + case FLOAT, BFLOAT16 -> { float[] floats = new float[vectorFieldType.getVectorDimensions()]; float magnitude = 0; for (int i = 0; i < floats.length; i++) { @@ -3063,7 +3063,7 @@ public SyntheticSourceExample example(int maxValues) throws IOException { Object value = switch (elementType) { case BYTE, BIT: yield randomList(dims, dims, ESTestCase::randomByte); - case FLOAT: + case FLOAT, BFLOAT16: yield randomList(dims, dims, ESTestCase::randomFloat); }; return new SyntheticSourceExample(value, value, this::mapping); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5074fac785006..2524422ed8f90 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -90,15 +90,11 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsA randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQHnswIndexOptions( - randomFrom(16, 32), randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), - new DenseVectorFieldMapper.BBQFlatIndexOptions( - randomFrom(16, 32), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ) + 
new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) ); } @@ -122,12 +118,7 @@ private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsHnswQua randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), rescoreVector ), - new DenseVectorFieldMapper.BBQHnswIndexOptions( - randomFrom(16, 32), - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - rescoreVector - ) + new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), rescoreVector) ); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index a8d9b1259cb41..cf85e8cea49b5 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -243,7 +243,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que filterQuery, expectedStrategy ); - case FLOAT -> new ESKnnFloatVectorQuery( + case FLOAT, BFLOAT16 -> new ESKnnFloatVectorQuery( VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, @@ -264,7 +264,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (filterQuery != null) { yield new BooleanQuery.Builder().add( new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD), diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 044af0ab1d37d..ca35823c452b8 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -255,7 +255,7 @@ private static List generateEmbedding(String input, int dimensions, Dense // Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BYTE, BFLOAT16 -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index aadbf94cfad11..0d260a557f602 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1295,7 +1295,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDense int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(32, m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); } static 
SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index 6e1407beab1d8..b59854528758e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -96,7 +96,7 @@ public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoint return switch (model.apiServiceSettings().elementType()) { case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); case BYTE -> TextEmbeddingBytes.PARSER.apply(p, null); - case FLOAT -> TextEmbeddingFloat.PARSER.apply(p, null); + case FLOAT, BFLOAT16 -> TextEmbeddingFloat.PARSER.apply(p, null); }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 175c3e90f798d..9bc1736a85c7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -269,7 +269,7 @@ private static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> 
randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 74f18e9c45ab1..7ee178cbe2af6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -1317,7 +1317,7 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDens int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(32, m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); } private static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index d1499f4009d0a..fb6b5b85c9414 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -190,7 +190,7 @@ public static ChunkedInferenceEmbedding 
randomChunkedInferenceEmbedding(Model mo return switch (model.getTaskType()) { case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); @@ -222,7 +222,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); - assert elementType == DenseVectorFieldMapper.ElementType.FLOAT; + assert elementType == DenseVectorFieldMapper.ElementType.FLOAT || elementType == DenseVectorFieldMapper.ElementType.BFLOAT16; List chunks = new ArrayList<>(); for (String input : inputs) { @@ -272,7 +272,7 @@ public static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); @@ -415,7 +415,7 @@ public static ChunkedInference toChunkedResult( ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText); double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, 
field.contentType()); EmbeddingResults.Embedding embedding = switch (elementType) { - case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case FLOAT, BFLOAT16 -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values)); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index a7eda7112723e..6d54a4dd9eaca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -267,7 +267,7 @@ private void assertTextEmbeddingLuceneQuery(Query query) { Query innerQuery = assertOuterBooleanQuery(query); Class expectedKnnQueryClass = switch (denseVectorElementType) { - case FLOAT -> KnnFloatVectorQuery.class; + case FLOAT, BFLOAT16 -> KnnFloatVectorQuery.class; case BYTE, BIT -> KnnByteVectorQuery.class; }; assertThat(innerQuery, instanceOf(expectedKnnQueryClass)); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java index b858b935c1483..e303e4bc7e834 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import 
org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; @@ -123,7 +124,47 @@ public Object nextValue() { VectorIterator iterator = new FloatRankVectorsDocValuesField.FloatVectorIterator(ref, vector, numVecs); while (iterator.hasNext()) { float[] v = iterator.next(); - vectors.add(Arrays.copyOf(v, v.length)); + vectors.add(v.clone()); + } + return vectors; + } + }; + case BFLOAT16 -> new FormattedDocValues() { + private final float[] vector = new float[dims]; + private BytesRef ref = null; + private int numVecs = -1; + private final BinaryDocValues binary; + { + try { + binary = DocValues.getBinary(reader, field); + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } + } + + @Override + public boolean advanceExact(int docId) throws IOException { + if (binary == null || binary.advanceExact(docId) == false) { + return false; + } + ref = binary.binaryValue(); + assert ref.length % (Short.BYTES * dims) == 0; + numVecs = ref.length / (Short.BYTES * dims); + return true; + } + + @Override + public int docValueCount() { + return 1; + } + + @Override + public Object nextValue() { + List vectors = new ArrayList<>(numVecs); + VectorIterator iterator = new BFloat16RankVectorsDocValuesField.BFloat16VectorIterator(ref, vector, numVecs); + while (iterator.hasNext()) { + float[] v = iterator.next(); + vectors.add(v.clone()); } return vectors; } @@ -140,6 +181,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case FLOAT -> new 
FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case BFLOAT16 -> new BFloat16RankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); }; } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java index 3d692db08ff9e..df47d43c6a002 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java @@ -351,7 +351,7 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new MaxSimFloatDotProduct(scoreScript, field, (List>) queryVector); } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index 6f80d7b46f563..3e3ee2fd67336 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -417,7 +417,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield vectors; } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { float[][] vectors = new 
float[numVectors][vectorFieldType.getVectorDimensions()]; for (int i = 0; i < numVectors; i++) { for (int j = 0; j < vectorFieldType.getVectorDimensions(); j++) { @@ -473,7 +473,7 @@ public SyntheticSourceExample example(int maxValues) { Object value = switch (elementType) { case BYTE, BIT: yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); - case FLOAT: + case FLOAT, BFLOAT16: yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); }; return new SyntheticSourceExample(value, value, this::mapping); diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java index f0b00849557bd..6523596fadfa8 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; +import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.RankVectors; @@ -50,6 +51,29 @@ public void testFloatGetVectorValueAndGetMagnitude() throws IOException { } } } + public void testBFloat16GetVectorValueAndGetMagnitude() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 
2.2361f } }; + + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + assertEquals(vectors[i].length, field.size()); + assertEquals(dims, scriptDocValues.dims()); + Iterator iterator = scriptDocValues.getVectorValues(); + float[] magnitudes = scriptDocValues.getMagnitudes(); + assertEquals(expectedMagnitudes[i].length, magnitudes.length); + for (int j = 0; j < vectors[i].length; j++) { + assertTrue(iterator.hasNext()); + assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f); + assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f); + } + } + } public void testByteGetVectorValueAndGetMagnitude() throws IOException { int dims = 3; @@ -97,6 +121,28 @@ public void testFloatMetadataAndIterator() throws IOException { assertEquals(dv, RankVectors.EMPTY); } + public void testBFloat16MetadataAndIterator() throws IOException { + int dims = 3; + float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BFLOAT16), fill(new float[2][dims], ElementType.BFLOAT16) }; + float[][] magnitudes = new float[][] { new float[3], new float[2] }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + RankVectors dv = field.get(); + assertEquals(vectors[i].length, dv.size()); + assertFalse(dv.isEmpty()); + assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, 
field::iterator); + assertEquals("Cannot iterate over single valued rank_vectors field, use get() instead", e.getMessage()); + } + field.setNextDocId(vectors.length); + RankVectors dv = field.get(); + assertEquals(dv, RankVectors.EMPTY); + } + public void testByteMetadataAndIterator() throws IOException { int dims = 3; float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BYTE), fill(new float[2][dims], ElementType.BYTE) }; @@ -145,6 +191,24 @@ public void testFloatMissingValues() throws IOException { assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); } + public void testBFloat16MissingValues() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(3); + assertEquals(0, field.size()); + Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); + } + public void testByteMissingValues() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -183,6 +247,26 @@ public void testFloatGetFunctionIsNotAccessible() throws IOException { ); } + public void testBFloat16GetFunctionIsNotAccessible() throws IOException { + int dims = 3; + float[][][] vectors = { 
{ { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(0); + Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat( + e.getMessage(), + containsString( + "accessing a rank-vectors field's value through 'get' or 'value' is not supported," + + " use 'vectorValues' or 'magnitudes' instead." + ) + ); + } + public void testByteGetFunctionIsNotAccessible() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -304,12 +388,11 @@ public static BytesRef mockEncodeDenseVector(float[][] values, ElementType eleme ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes * values.length); for (float[] vector : values) { for (float value : vector) { - if (elementType == ElementType.FLOAT) { - byteBuffer.putFloat(value); - } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) { - byteBuffer.put((byte) value); - } else { - throw new IllegalStateException("unknown element_type [" + elementType + "]"); + switch (elementType) { + case FLOAT -> byteBuffer.putFloat(value); + case BFLOAT16 -> byteBuffer.putShort((short)(Float.floatToIntBits(value) >>> 16)); + case BYTE, BIT -> byteBuffer.put((byte) value); + default -> throw new IllegalStateException("unknown element_type [" + elementType + "]"); } } } From 22e0756a6ed718b07acbcb89820d9f7e099e8cec Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 8 Aug 2025 10:43:21 +0000 Subject: [PATCH 06/17] [CI] Auto commit changes from spotless 
--- .../mapper/RankVectorsDVLeafFieldData.java | 1 - .../RankVectorsScriptDocValuesTests.java | 39 ++++++++++++++++--- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java index e303e4bc7e834..60d6bfc5586ee 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java @@ -25,7 +25,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; final class RankVectorsDVLeafFieldData implements LeafFieldData { diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java index 6523596fadfa8..8080762f2d492 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java @@ -51,6 +51,7 @@ public void testFloatGetVectorValueAndGetMagnitude() throws IOException { } } } + public void testBFloat16GetVectorValueAndGetMagnitude() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -58,7 +59,13 @@ public void testBFloat16GetVectorValueAndGetMagnitude() throws IOException { BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); - RankVectorsDocValuesField field = new 
BFloat16RankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); for (int i = 0; i < vectors.length; i++) { field.setNextDocId(i); @@ -123,12 +130,20 @@ public void testFloatMetadataAndIterator() throws IOException { public void testBFloat16MetadataAndIterator() throws IOException { int dims = 3; - float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BFLOAT16), fill(new float[2][dims], ElementType.BFLOAT16) }; + float[][][] vectors = new float[][][] { + fill(new float[3][dims], ElementType.BFLOAT16), + fill(new float[2][dims], ElementType.BFLOAT16) }; float[][] magnitudes = new float[][] { new float[3], new float[2] }; BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); BinaryDocValues magnitudeValues = wrap(magnitudes); - RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); for (int i = 0; i < vectors.length; i++) { field.setNextDocId(i); RankVectors dv = field.get(); @@ -197,7 +212,13 @@ public void testBFloat16MissingValues() throws IOException { float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); BinaryDocValues magnitudeValues = wrap(magnitudes); - RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); RankVectorsScriptDocValues 
scriptDocValues = field.toScriptDocValues(); field.setNextDocId(3); @@ -253,7 +274,13 @@ public void testBFloat16GetFunctionIsNotAccessible() throws IOException { float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); BinaryDocValues magnitudeValues = wrap(magnitudes); - RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BFLOAT16, dims); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); field.setNextDocId(0); @@ -390,7 +417,7 @@ public static BytesRef mockEncodeDenseVector(float[][] values, ElementType eleme for (float value : vector) { switch (elementType) { case FLOAT -> byteBuffer.putFloat(value); - case BFLOAT16 -> byteBuffer.putShort((short)(Float.floatToIntBits(value) >>> 16)); + case BFLOAT16 -> byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); case BYTE, BIT -> byteBuffer.put((byte) value); default -> throw new IllegalStateException("unknown element_type [" + elementType + "]"); } From 6ff4792c5b5d5524ac61d5dee7c2478b96f2b483 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 8 Aug 2025 14:22:07 +0100 Subject: [PATCH 07/17] Only run bfloat16 upgrade tests if bfloat16 is actually supported --- .../elasticsearch/index/mapper/MapperFeatures.java | 2 +- .../upgrades/SemanticTextUpgradeIT.java | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 2480a7a373bef..be52ced30bff5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -47,7 
+47,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature BBQ_DISK_SUPPORT = new NodeFeature("mapper.bbq_disk_support"); static final NodeFeature SEARCH_LOAD_PER_SHARD = new NodeFeature("mapper.search_load_per_shard"); static final NodeFeature PATTERNED_TEXT = new NodeFeature("mapper.patterned_text"); - static final NodeFeature BBQ_BFLOAT16 = new NodeFeature("mapper.bbq_bfloat16"); + public static final NodeFeature BBQ_BFLOAT16 = new NodeFeature("mapper.bbq_bfloat16"); @Override public Set getTestFeatures() { diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index f21525103d37d..94e33ec61cb3f 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.MapperFeatures; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.index.query.NestedQueryBuilder; @@ -85,12 +86,21 @@ public static Iterable parameters() { public void testSemanticTextOperations() throws Exception { switch (CLUSTER_TYPE) { - case OLD -> createAndPopulateIndex(); + case OLD -> { + checkSupportsBFloat16(); + createAndPopulateIndex(); + } case MIXED, UPGRADED -> performIndexQueryHighlightOps(); default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]"); } } + private static void checkSupportsBFloat16() { + assumeTrue("The old cluster needs to support bfloat16 if it is used", + 
DENSE_MODEL.getServiceSettings().elementType() != DenseVectorFieldMapper.ElementType.BFLOAT16 + || clusterHasFeature(MapperFeatures.BBQ_BFLOAT16)); + } + private void createAndPopulateIndex() throws IOException { final String indexName = getIndexName(); final String mapping = Strings.format(""" From d35e0bcb55f0ec24d5a5eaf322b15a81def538cb Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 8 Aug 2025 13:28:55 +0000 Subject: [PATCH 08/17] [CI] Auto commit changes from spotless --- .../org/elasticsearch/upgrades/SemanticTextUpgradeIT.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index 94e33ec61cb3f..51cb6a0b58439 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -96,9 +96,11 @@ public void testSemanticTextOperations() throws Exception { } private static void checkSupportsBFloat16() { - assumeTrue("The old cluster needs to support bfloat16 if it is used", + assumeTrue( + "The old cluster needs to support bfloat16 if it is used", DENSE_MODEL.getServiceSettings().elementType() != DenseVectorFieldMapper.ElementType.BFLOAT16 - || clusterHasFeature(MapperFeatures.BBQ_BFLOAT16)); + || clusterHasFeature(MapperFeatures.BBQ_BFLOAT16) + ); } private void createAndPopulateIndex() throws IOException { From b142470ba1af894c161ec3ef22353bfc809b5887 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 8 Aug 2025 15:53:45 +0100 Subject: [PATCH 09/17] Remove unneeded reader/writer --- .../{es818 => }/MergeReaderWrapper.java | 2 +- .../es818/BinarizedByteVectorValues.java | 2 +- .../DirectIOLucene99FlatVectorsFormat.java | 1 + .../es818/ES818BinaryFlatVectorsScorer.java | 4 +- 
.../ES818BinaryQuantizedVectorsReader.java | 8 +- .../ES818BinaryQuantizedVectorsWriter.java | 22 +- .../es818/OffHeapBinarizedVectorValues.java | 8 +- .../es92/ES92BFloat16FlatVectorsFormat.java | 2 +- ...2BinaryQuantizedBFloat16VectorsFormat.java | 6 +- ...2BinaryQuantizedBFloat16VectorsReader.java | 391 ---------- ...2BinaryQuantizedBFloat16VectorsWriter.java | 726 ------------------ 11 files changed, 27 insertions(+), 1145 deletions(-) rename server/src/main/java/org/elasticsearch/index/codec/vectors/{es818 => }/MergeReaderWrapper.java (98%) delete mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java delete mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java similarity index 98% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java rename to server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java index 915f3b4ced18e..afc2ca77dda2e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.index.codec.vectors.es818; +package org.elasticsearch.index.codec.vectors; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.index.ByteVectorValues; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java index 53867d9d1e494..ca80ba52e2c2b 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java @@ -30,7 +30,7 @@ /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ -public abstract class BinarizedByteVectorValues extends ByteVectorValues { +abstract class BinarizedByteVectorValues extends ByteVectorValues { /** * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java index 8e328b5c500ad..c907bac2a2e13 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java @@ -31,6 +31,7 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.MergeInfo; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index bbc3d328d7767..efb098373489f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -108,7 +108,7 @@ public RandomVectorScorer getRandomVectorScorer( return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } - public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, BinarizedByteVectorValues targetVectors @@ -122,7 +122,7 @@ public String toString() { } /** Vector scorer supplier over binarized vector values */ - public static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; private final BinarizedByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java index 0a59a69416182..2d3699d915f81 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -71,7 +71,7 @@ public class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader { private final ES818BinaryFlatVectorsScorer vectorScorer; @SuppressWarnings("this-escape") - 
ES818BinaryQuantizedVectorsReader( + public ES818BinaryQuantizedVectorsReader( SegmentReadState state, FlatVectorsReader rawVectorsReader, ES818BinaryFlatVectorsScorer vectorsScorer @@ -388,11 +388,11 @@ static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, Vector } /** Binarized vector values holding row and quantized vector values */ - public static final class BinarizedVectorValues extends FloatVectorValues { + protected static final class BinarizedVectorValues extends FloatVectorValues { private final FloatVectorValues rawVectorValues; private final BinarizedByteVectorValues quantizedVectorValues; - public BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { this.rawVectorValues = rawVectorValues; this.quantizedVectorValues = quantizedVectorValues; } @@ -437,7 +437,7 @@ public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); } - public BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { + BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { return quantizedVectorValues; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index 12cf4742b6b47..abceb15f62ad5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -85,7 +85,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { * @param vectorsScorer the scorer to use for scoring vectors */ @SuppressWarnings("this-escape") - protected ES818BinaryQuantizedVectorsWriter( + 
public ES818BinaryQuantizedVectorsWriter( ES818BinaryFlatVectorsScorer vectorsScorer, FlatVectorsWriter rawVectorDelegate, SegmentWriteState state @@ -723,7 +723,7 @@ public long ramBytesUsed() { } // When accessing vectorValue method, targerOrd here means a row ordinal. - public static class OffHeapBinarizedQueryVectorValues { + static class OffHeapBinarizedQueryVectorValues { private final IndexInput slice; private final int dimension; private final int size; @@ -734,7 +734,7 @@ public static class OffHeapBinarizedQueryVectorValues { private int lastOrd = -1; private int quantizedComponentSum; - public OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { + OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { this.slice = data; this.dimension = dimension; this.size = size; @@ -798,7 +798,7 @@ public byte[] vectorValue(int targetOrd) throws IOException { } } - public static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { private OptimizedScalarQuantizer.QuantizationResult corrections; private final byte[] binarized; private final int[] initQuantized; @@ -808,7 +808,7 @@ public static class BinarizedFloatVectorValues extends BinarizedByteVectorValues private int lastOrd = -1; - public BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { this.values = delegate; this.quantizer = quantizer; this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; @@ -881,16 +881,12 @@ public int ordToDoc(int ord) { } } - public static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { 
private final RandomVectorScorerSupplier supplier; private final KnnVectorValues vectorValues; private final Closeable onClose; - public BinarizedCloseableRandomVectorScorerSupplier( - RandomVectorScorerSupplier supplier, - KnnVectorValues vectorValues, - Closeable onClose - ) { + BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { this.supplier = supplier; this.onClose = onClose; this.vectorValues = vectorValues; @@ -917,11 +913,11 @@ public int totalVectorCount() { } } - public static final class NormalizedFloatVectorValues extends FloatVectorValues { + static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; - public NormalizedFloatVectorValues(FloatVectorValues values) { + NormalizedFloatVectorValues(FloatVectorValues values) { this.values = values; this.normalizedVector = new float[values.dimension()]; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java index 8d1474b7ecf31..0357468c6864d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java @@ -36,7 +36,7 @@ import java.nio.ByteBuffer; /** Binarized vector values loaded from off-heap */ -public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { +abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { final int dimension; final int size; @@ -151,7 +151,7 @@ public int getVectorByteLength() { return numBytes; } - public static OffHeapBinarizedVectorValues load( + static OffHeapBinarizedVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, 
int size, @@ -197,8 +197,8 @@ public static OffHeapBinarizedVectorValues load( } /** Dense off-heap binarized vector values */ - public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { - public DenseOffHeapVectorValues( + static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + DenseOffHeapVectorValues( int dimension, int size, float[] centroid, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java index 859001dd218b2..1b134e9b45242 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java @@ -31,7 +31,7 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.MergeReaderWrapper; +import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java index 642ec6807a93c..c029cecb2200b 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java @@ -27,6 +27,8 @@ import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import 
org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; import java.io.IOException; @@ -113,12 +115,12 @@ public ES92BinaryQuantizedBFloat16VectorsFormat() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES92BinaryQuantizedBFloat16VectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new ES92BinaryQuantizedBFloat16VectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java deleted file mode 100644 index d3f43851fe488..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsReader.java +++ /dev/null @@ -1,391 +0,0 @@ -/* - * @notice - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * Modifications copyright (C) 2024 Elasticsearch B.V. - */ -package org.elasticsearch.index.codec.vectors.es92; - -import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.hnsw.FlatVectorsReader; -import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.CorruptIndexException; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.store.ChecksumIndexInput; -import org.apache.lucene.store.DataAccessHint; -import org.apache.lucene.store.FileDataHint; -import org.apache.lucene.store.FileTypeHint; -import org.apache.lucene.store.IOContext; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; -import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.RamUsageEstimator; -import org.apache.lucene.util.SuppressForbidden; -import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; -import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.elasticsearch.index.codec.vectors.BQVectorUtils; -import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; -import 
org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; -import org.elasticsearch.index.codec.vectors.es818.OffHeapBinarizedVectorValues; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; -import static org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_EXTENSION; - -/** - * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 - */ -@SuppressForbidden(reason = "Lucene classes") -public class ES92BinaryQuantizedBFloat16VectorsReader extends FlatVectorsReader { - - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES92BinaryQuantizedBFloat16VectorsReader.class); - - private final Map fields; - private final IndexInput quantizedVectorData; - private final FlatVectorsReader rawVectorsReader; - private final ES818BinaryFlatVectorsScorer vectorScorer; - - @SuppressWarnings("this-escape") - ES92BinaryQuantizedBFloat16VectorsReader( - SegmentReadState state, - FlatVectorsReader rawVectorsReader, - ES818BinaryFlatVectorsScorer vectorsScorer - ) throws IOException { - super(vectorsScorer); - this.fields = new HashMap<>(); - this.vectorScorer = vectorsScorer; - this.rawVectorsReader = rawVectorsReader; - int versionMeta = -1; - String metaFileName = IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - ES92BinaryQuantizedBFloat16VectorsFormat.META_EXTENSION - ); - boolean success = false; - try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { - Throwable priorE = null; - try { - versionMeta = CodecUtil.checkIndexHeader( - meta, - ES92BinaryQuantizedBFloat16VectorsFormat.META_CODEC_NAME, - 
ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_START, - ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix - ); - readFields(meta, state.fieldInfos); - } catch (Throwable exception) { - priorE = exception; - } finally { - CodecUtil.checkFooter(meta, priorE); - } - quantizedVectorData = openDataInput( - state, - versionMeta, - VECTOR_DATA_EXTENSION, - ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_CODEC_NAME, - // Quantized vectors are accessed randomly from their node ID stored in the HNSW - // graph. - state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) - ); - success = true; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(this); - } - } - } - - private ES92BinaryQuantizedBFloat16VectorsReader(ES92BinaryQuantizedBFloat16VectorsReader clone, FlatVectorsReader rawVectorsReader) { - super(clone.vectorScorer); - this.rawVectorsReader = rawVectorsReader; - this.vectorScorer = clone.vectorScorer; - this.quantizedVectorData = clone.quantizedVectorData; - this.fields = clone.fields; - } - - @Override - public FlatVectorsReader getMergeInstance() { - return new ES92BinaryQuantizedBFloat16VectorsReader(this, rawVectorsReader.getMergeInstance()); - } - - private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { - for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { - FieldInfo info = infos.fieldInfo(fieldNumber); - if (info == null) { - throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); - } - FieldEntry fieldEntry = readField(meta, info); - validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); - } - } - - static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { - int dimension = info.getVectorDimension(); - if (dimension != fieldEntry.dimension) { - throw new IllegalStateException( - "Inconsistent vector 
dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension - ); - } - - int binaryDims = BQVectorUtils.discretize(dimension, 64) / 8; - long numQuantizedVectorBytes = Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size); - if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { - throw new IllegalStateException( - "Binarized vector data length " - + fieldEntry.vectorDataLength - + " not matching size = " - + fieldEntry.size - + " * (binaryBytes=" - + binaryDims - + " + 14" - + ") = " - + numQuantizedVectorBytes - ); - } - } - - @Override - public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { - FieldEntry fi = fields.get(field); - if (fi == null || fi.size() == 0) { - return null; - } - return vectorScorer.getRandomVectorScorer( - fi.similarityFunction, - OffHeapBinarizedVectorValues.load( - fi.ordToDocDISIReaderConfiguration, - fi.dimension, - fi.size, - new OptimizedScalarQuantizer(fi.similarityFunction), - fi.similarityFunction, - vectorScorer, - fi.centroid, - fi.centroidDP, - fi.vectorDataOffset, - fi.vectorDataLength, - quantizedVectorData - ), - target - ); - } - - @Override - public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { - return rawVectorsReader.getRandomVectorScorer(field, target); - } - - @Override - public void checkIntegrity() throws IOException { - rawVectorsReader.checkIntegrity(); - CodecUtil.checksumEntireFile(quantizedVectorData); - } - - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fi = fields.get(field); - if (fi == null) { - return null; - } - if (fi.vectorEncoding != VectorEncoding.FLOAT32) { - throw new IllegalArgumentException( - "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32 - ); - } - OffHeapBinarizedVectorValues bvv = OffHeapBinarizedVectorValues.load( - 
fi.ordToDocDISIReaderConfiguration, - fi.dimension, - fi.size, - new OptimizedScalarQuantizer(fi.similarityFunction), - fi.similarityFunction, - vectorScorer, - fi.centroid, - fi.centroidDP, - fi.vectorDataOffset, - fi.vectorDataLength, - quantizedVectorData - ); - return new ES818BinaryQuantizedVectorsReader.BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); - } - - @Override - public ByteVectorValues getByteVectorValues(String field) throws IOException { - return rawVectorsReader.getByteVectorValues(field); - } - - @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - rawVectorsReader.search(field, target, knnCollector, acceptDocs); - } - - @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - if (knnCollector.k() == 0) return; - final RandomVectorScorer scorer = getRandomVectorScorer(field, target); - if (scorer == null) return; - OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); - Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); - for (int i = 0; i < scorer.maxOrd(); i++) { - if (acceptedOrds == null || acceptedOrds.get(i)) { - collector.collect(i, scorer.score(i)); - collector.incVisitedCount(1); - } - } - } - - @Override - public void close() throws IOException { - IOUtils.close(quantizedVectorData, rawVectorsReader); - } - - @Override - public long ramBytesUsed() { - long size = SHALLOW_SIZE; - size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); - size += rawVectorsReader.ramBytesUsed(); - return size; - } - - @Override - public Map getOffHeapByteSize(FieldInfo fieldInfo) { - var raw = rawVectorsReader.getOffHeapByteSize(fieldInfo); - FieldEntry fe = fields.get(fieldInfo.name); - if (fe == null) { - assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE; - return raw; - } 
- var quant = Map.of(VECTOR_DATA_EXTENSION, fe.vectorDataLength()); - return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant); - } - - public float[] getCentroid(String field) { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry != null) { - return fieldEntry.centroid; - } - return null; - } - - private static IndexInput openDataInput( - SegmentReadState state, - int versionMeta, - String fileExtension, - String codecName, - IOContext context - ) throws IOException { - String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); - IndexInput in = state.directory.openInput(fileName, context); - boolean success = false; - try { - int versionVectorData = CodecUtil.checkIndexHeader( - in, - codecName, - ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_START, - ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix - ); - if (versionMeta != versionVectorData) { - throw new CorruptIndexException( - "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, - in - ); - } - CodecUtil.retrieveChecksum(in); - success = true; - return in; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(in); - } - } - } - - private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { - VectorEncoding vectorEncoding = readVectorEncoding(input); - VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); - if (similarityFunction != info.getVectorSimilarityFunction()) { - throw new IllegalStateException( - "Inconsistent vector similarity function for field=\"" - + info.name - + "\"; " - + similarityFunction - + " != " - + info.getVectorSimilarityFunction() - ); - } - return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); - } - - private record FieldEntry( - VectorSimilarityFunction similarityFunction, - VectorEncoding vectorEncoding, - int 
dimension, - int descritizedDimension, - long vectorDataOffset, - long vectorDataLength, - int size, - float[] centroid, - float centroidDP, - OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration - ) { - - static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) - throws IOException { - int dimension = input.readVInt(); - long vectorDataOffset = input.readVLong(); - long vectorDataLength = input.readVLong(); - int size = input.readVInt(); - final float[] centroid; - float centroidDP = 0; - if (size > 0) { - centroid = new float[dimension]; - input.readFloats(centroid, 0, dimension); - centroidDP = Float.intBitsToFloat(input.readInt()); - } else { - centroid = null; - } - OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); - return new FieldEntry( - similarityFunction, - vectorEncoding, - dimension, - BQVectorUtils.discretize(dimension, 64), - vectorDataOffset, - vectorDataLength, - size, - centroid, - centroidDP, - conf - ); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java deleted file mode 100644 index c071029d6e99d..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsWriter.java +++ /dev/null @@ -1,726 +0,0 @@ -/* - * @notice - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * Modifications copyright (C) 2024 Elasticsearch B.V. - */ -package org.elasticsearch.index.codec.vectors.es92; - -import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; -import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; -import org.apache.lucene.index.DocsWithFieldSet; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.MergeState; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.Sorter; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.internal.hppc.FloatArrayList; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.elasticsearch.core.SuppressForbidden; -import org.elasticsearch.index.codec.vectors.BQSpaceUtils; -import org.elasticsearch.index.codec.vectors.BQVectorUtils; -import 
org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; -import org.elasticsearch.index.codec.vectors.es818.BinarizedByteVectorValues; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; -import org.elasticsearch.index.codec.vectors.es818.OffHeapBinarizedVectorValues; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; -import static org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat.BINARIZED_VECTOR_COMPONENT; -import static org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; - -/** - * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 - */ -@SuppressForbidden(reason = "Lucene classes") -public class ES92BinaryQuantizedBFloat16VectorsWriter extends FlatVectorsWriter { - private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES92BinaryQuantizedBFloat16VectorsWriter.class); - - private final SegmentWriteState segmentWriteState; - private final List fields = new ArrayList<>(); - private final IndexOutput meta, binarizedVectorData; - private final FlatVectorsWriter rawVectorDelegate; - private final ES818BinaryFlatVectorsScorer vectorsScorer; - private boolean finished; - - /** - * Sole constructor - * - * @param vectorsScorer the scorer to use for scoring vectors - */ - @SuppressWarnings("this-escape") - protected ES92BinaryQuantizedBFloat16VectorsWriter( - 
ES818BinaryFlatVectorsScorer vectorsScorer, - FlatVectorsWriter rawVectorDelegate, - SegmentWriteState state - ) throws IOException { - super(vectorsScorer); - this.vectorsScorer = vectorsScorer; - this.segmentWriteState = state; - String metaFileName = IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - ES92BinaryQuantizedBFloat16VectorsFormat.META_EXTENSION - ); - - String binarizedVectorDataFileName = IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_EXTENSION - ); - this.rawVectorDelegate = rawVectorDelegate; - boolean success = false; - try { - meta = state.directory.createOutput(metaFileName, state.context); - binarizedVectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); - - CodecUtil.writeIndexHeader( - meta, - ES92BinaryQuantizedBFloat16VectorsFormat.META_CODEC_NAME, - ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix - ); - CodecUtil.writeIndexHeader( - binarizedVectorData, - ES92BinaryQuantizedBFloat16VectorsFormat.VECTOR_DATA_CODEC_NAME, - ES92BinaryQuantizedBFloat16VectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix - ); - success = true; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(this); - } - } - } - - @Override - public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); - if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - @SuppressWarnings("unchecked") - FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); - fields.add(fieldWriter); - return fieldWriter; - } - return rawVectorDelegate; - } - - @Override - public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { - 
rawVectorDelegate.flush(maxDoc, sortMap); - for (FieldWriter field : fields) { - // after raw vectors are written, normalize vectors for clustering and quantization - if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { - field.normalizeVectors(); - } - final float[] clusterCenter; - int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); - clusterCenter = new float[field.dimensionSums.length]; - if (vectorCount > 0) { - for (int i = 0; i < field.dimensionSums.length; i++) { - clusterCenter[i] = field.dimensionSums[i] / vectorCount; - } - if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { - VectorUtil.l2normalize(clusterCenter); - } - } - if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { - segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); - } - OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); - if (sortMap == null) { - writeField(field, clusterCenter, maxDoc, quantizer); - } else { - writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); - } - field.finish(); - } - } - - private void writeField(FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) - throws IOException { - // write vector values - long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); - writeBinarizedVectors(fieldData, clusterCenter, quantizer); - long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; - float centroidDp = fieldData.getVectors().size() > 0 ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; - - writeMeta( - fieldData.fieldInfo, - maxDoc, - vectorDataOffset, - vectorDataLength, - clusterCenter, - centroidDp, - fieldData.getDocsWithFieldSet() - ); - } - - private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) - throws IOException { - int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); - int[] quantizationScratch = new int[discreteDims]; - byte[] vector = new byte[discreteDims / 8]; - for (int i = 0; i < fieldData.getVectors().size(); i++) { - float[] v = fieldData.getVectors().get(i); - OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( - v, - quantizationScratch, - (byte) 1, - clusterCenter - ); - BQVectorUtils.packAsBinary(quantizationScratch, vector); - binarizedVectorData.writeBytes(vector, vector.length); - binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); - binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); - binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); - assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; - binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); - } - } - - private void writeSortingField( - FieldWriter fieldData, - float[] clusterCenter, - int maxDoc, - Sorter.DocMap sortMap, - OptimizedScalarQuantizer scalarQuantizer - ) throws IOException { - final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord - - DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); - mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); - - // write vector values - long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); - writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, 
scalarQuantizer); - long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; - - float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); - writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, quantizedVectorLength, clusterCenter, centroidDp, newDocsWithField); - } - - private void writeSortedBinarizedVectors( - FieldWriter fieldData, - float[] clusterCenter, - int[] ordMap, - OptimizedScalarQuantizer scalarQuantizer - ) throws IOException { - int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); - int[] quantizationScratch = new int[discreteDims]; - byte[] vector = new byte[discreteDims / 8]; - for (int ordinal : ordMap) { - float[] v = fieldData.getVectors().get(ordinal); - OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( - v, - quantizationScratch, - (byte) 1, - clusterCenter - ); - BQVectorUtils.packAsBinary(quantizationScratch, vector); - binarizedVectorData.writeBytes(vector, vector.length); - binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); - binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); - binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); - assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; - binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); - } - } - - private void writeMeta( - FieldInfo field, - int maxDoc, - long vectorDataOffset, - long vectorDataLength, - float[] clusterCenter, - float centroidDp, - DocsWithFieldSet docsWithField - ) throws IOException { - meta.writeInt(field.number); - meta.writeInt(field.getVectorEncoding().ordinal()); - meta.writeInt(field.getVectorSimilarityFunction().ordinal()); - meta.writeVInt(field.getVectorDimension()); - meta.writeVLong(vectorDataOffset); - meta.writeVLong(vectorDataLength); - int count = 
docsWithField.cardinality(); - meta.writeVInt(count); - if (count > 0) { - final ByteBuffer buffer = ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - buffer.asFloatBuffer().put(clusterCenter); - meta.writeBytes(buffer.array(), buffer.array().length); - meta.writeInt(Float.floatToIntBits(centroidDp)); - } - OrdToDocDISIReaderConfiguration.writeStoredMeta( - DIRECT_MONOTONIC_BLOCK_SHIFT, - meta, - binarizedVectorData, - count, - maxDoc, - docsWithField - ); - } - - @Override - public void finish() throws IOException { - if (finished) { - throw new IllegalStateException("already finished"); - } - finished = true; - rawVectorDelegate.finish(); - if (meta != null) { - // write end of fields marker - meta.writeInt(-1); - CodecUtil.writeFooter(meta); - } - if (binarizedVectorData != null) { - CodecUtil.writeFooter(binarizedVectorData); - } - } - - @Override - public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - final float[] centroid; - final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; - int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); - // Don't need access to the random vectors, we can just use the merged - rawVectorDelegate.mergeOneField(fieldInfo, mergeState); - centroid = mergedCentroid; - if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { - segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); - } - FloatVectorValues floatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - floatVectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(floatVectorValues); - } - ES818BinaryQuantizedVectorsWriter.BinarizedFloatVectorValues binarizedVectorValues = - new 
ES818BinaryQuantizedVectorsWriter.BinarizedFloatVectorValues( - floatVectorValues, - new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), - centroid - ); - long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); - DocsWithFieldSet docsWithField = writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); - long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; - float centroidDp = docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; - writeMeta( - fieldInfo, - segmentWriteState.segmentInfo.maxDoc(), - vectorDataOffset, - vectorDataLength, - centroid, - centroidDp, - docsWithField - ); - } else { - rawVectorDelegate.mergeOneField(fieldInfo, mergeState); - } - } - - static DocsWithFieldSet writeBinarizedVectorAndQueryData( - IndexOutput binarizedVectorData, - IndexOutput binarizedQueryData, - FloatVectorValues floatVectorValues, - float[] centroid, - OptimizedScalarQuantizer binaryQuantizer - ) throws IOException { - int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64); - DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - int[][] quantizationScratch = new int[2][floatVectorValues.dimension()]; - byte[] toIndex = new byte[discretizedDimension / 8]; - byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY]; - KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); - for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { - // write index vector - OptimizedScalarQuantizer.QuantizationResult[] r = binaryQuantizer.multiScalarQuantize( - floatVectorValues.vectorValue(iterator.index()), - quantizationScratch, - new byte[] { 1, 4 }, - centroid - ); - // pack and store document bit vector - BQVectorUtils.packAsBinary(quantizationScratch[0], toIndex); - binarizedVectorData.writeBytes(toIndex, toIndex.length); - 
binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval())); - binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval())); - binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection())); - assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff; - binarizedVectorData.writeShort((short) r[0].quantizedComponentSum()); - docsWithField.add(docV); - - // pack and store the 4bit query vector - BQSpaceUtils.transposeHalfByte(quantizationScratch[1], toQuery); - binarizedQueryData.writeBytes(toQuery, toQuery.length); - binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval())); - binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval())); - binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection())); - assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff; - binarizedQueryData.writeShort((short) r[1].quantizedComponentSum()); - } - return docsWithField; - } - - static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) - throws IOException { - DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); - for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { - // write vector - byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); - output.writeBytes(binaryValue, binaryValue.length); - OptimizedScalarQuantizer.QuantizationResult corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); - output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); - output.writeInt(Float.floatToIntBits(corrections.upperInterval())); - output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); - assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; - 
output.writeShort((short) corrections.quantizedComponentSum()); - docsWithField.add(docV); - } - return docsWithField; - } - - @Override - public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - final float[] centroid; - final float cDotC; - final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; - int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); - - // Don't need access to the random vectors, we can just use the merged - rawVectorDelegate.mergeOneField(fieldInfo, mergeState); - centroid = mergedCentroid; - cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; - if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { - segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); - } - return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); - } - return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); - } - - private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( - SegmentWriteState segmentWriteState, - FieldInfo fieldInfo, - MergeState mergeState, - float[] centroid, - float cDotC - ) throws IOException { - long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); - IndexInput binarizedDataInput = null; - IndexInput binarizedScoreDataInput = null; - IndexOutput tempQuantizedVectorData = null; - IndexOutput tempScoreQuantizedVectorData = null; - boolean success = false; - OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - try { - // Since we are opening two files, it's possible that one or the other fails to open - // we open them within the try to ensure they are cleaned - tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( - 
binarizedVectorData.getName(), - "temp", - segmentWriteState.context - ); - tempScoreQuantizedVectorData = segmentWriteState.directory.createTempOutput( - binarizedVectorData.getName(), - "score_temp", - segmentWriteState.context - ); - final String tempQuantizedVectorDataName = tempQuantizedVectorData.getName(); - final String tempScoreQuantizedVectorDataName = tempScoreQuantizedVectorData.getName(); - FloatVectorValues floatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - floatVectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(floatVectorValues); - } - DocsWithFieldSet docsWithField = writeBinarizedVectorAndQueryData( - tempQuantizedVectorData, - tempScoreQuantizedVectorData, - floatVectorValues, - centroid, - quantizer - ); - CodecUtil.writeFooter(tempQuantizedVectorData); - IOUtils.close(tempQuantizedVectorData); - binarizedDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context); - binarizedVectorData.copyBytes(binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); - long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; - CodecUtil.retrieveChecksum(binarizedDataInput); - CodecUtil.writeFooter(tempScoreQuantizedVectorData); - IOUtils.close(tempScoreQuantizedVectorData); - binarizedScoreDataInput = segmentWriteState.directory.openInput( - tempScoreQuantizedVectorData.getName(), - segmentWriteState.context - ); - writeMeta( - fieldInfo, - segmentWriteState.segmentInfo.maxDoc(), - vectorDataOffset, - vectorDataLength, - centroid, - cDotC, - docsWithField - ); - success = true; - final IndexInput finalBinarizedDataInput = binarizedDataInput; - final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; - OffHeapBinarizedVectorValues vectorValues = new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( - 
fieldInfo.getVectorDimension(), - docsWithField.cardinality(), - centroid, - cDotC, - quantizer, - fieldInfo.getVectorSimilarityFunction(), - vectorsScorer, - finalBinarizedDataInput - ); - RandomVectorScorerSupplier scorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( - fieldInfo.getVectorSimilarityFunction(), - new ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues( - finalBinarizedScoreDataInput, - fieldInfo.getVectorDimension(), - docsWithField.cardinality() - ), - vectorValues - ); - return new ES818BinaryQuantizedVectorsWriter.BinarizedCloseableRandomVectorScorerSupplier(scorerSupplier, vectorValues, () -> { - IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); - IOUtils.deleteFilesIgnoringExceptions( - segmentWriteState.directory, - tempQuantizedVectorDataName, - tempScoreQuantizedVectorDataName - ); - }); - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException( - tempQuantizedVectorData, - tempScoreQuantizedVectorData, - binarizedDataInput, - binarizedScoreDataInput - ); - if (tempQuantizedVectorData != null) { - IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempQuantizedVectorData.getName()); - } - if (tempScoreQuantizedVectorData != null) { - IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempScoreQuantizedVectorData.getName()); - } - } - } - } - - @Override - public void close() throws IOException { - IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); - } - - static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } - if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) { - return reader.getCentroid(fieldName); - } - return null; - } - - static int mergeAndRecalculateCentroids(MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws 
IOException { - boolean recalculate = false; - int totalVectorCount = 0; - for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { - KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; - if (knnVectorsReader == null || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { - continue; - } - float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); - int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); - if (vectorCount == 0) { - continue; - } - totalVectorCount += vectorCount; - // If there aren't centroids, or previously clustered with more than one cluster - // or if there are deleted docs, we must recalculate the centroid - if (centroid == null || mergeState.liveDocs[i] != null) { - recalculate = true; - break; - } - for (int j = 0; j < centroid.length; j++) { - mergedCentroid[j] += centroid[j] * vectorCount; - } - } - if (recalculate) { - return calculateCentroid(mergeState, fieldInfo, mergedCentroid); - } else { - for (int j = 0; j < mergedCentroid.length; j++) { - mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; - } - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - VectorUtil.l2normalize(mergedCentroid); - } - return totalVectorCount; - } - } - - static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException { - assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); - // clear out the centroid - Arrays.fill(centroid, 0); - int count = 0; - for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { - KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; - if (knnVectorsReader == null) continue; - FloatVectorValues vectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); - if (vectorValues == null) { - continue; - } - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = 
iterator.nextDoc()) { - ++count; - float[] vector = vectorValues.vectorValue(iterator.index()); - // TODO Panama sum - for (int j = 0; j < vector.length; j++) { - centroid[j] += vector[j]; - } - } - } - if (count == 0) { - return count; - } - // TODO Panama div - for (int i = 0; i < centroid.length; i++) { - centroid[i] /= count; - } - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - VectorUtil.l2normalize(centroid); - } - return count; - } - - @Override - public long ramBytesUsed() { - long total = SHALLOW_RAM_BYTES_USED; - for (FieldWriter field : fields) { - // the field tracks the delegate field usage - total += field.ramBytesUsed(); - } - return total; - } - - static class FieldWriter extends FlatFieldVectorsWriter { - private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); - private final FieldInfo fieldInfo; - private boolean finished; - private final FlatFieldVectorsWriter flatFieldVectorsWriter; - private final float[] dimensionSums; - private final FloatArrayList magnitudes = new FloatArrayList(); - - FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { - this.fieldInfo = fieldInfo; - this.flatFieldVectorsWriter = flatFieldVectorsWriter; - this.dimensionSums = new float[fieldInfo.getVectorDimension()]; - } - - @Override - public List getVectors() { - return flatFieldVectorsWriter.getVectors(); - } - - public void normalizeVectors() { - for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { - float[] vector = flatFieldVectorsWriter.getVectors().get(i); - float magnitude = magnitudes.get(i); - for (int j = 0; j < vector.length; j++) { - vector[j] /= magnitude; - } - } - } - - @Override - public DocsWithFieldSet getDocsWithFieldSet() { - return flatFieldVectorsWriter.getDocsWithFieldSet(); - } - - @Override - public void finish() throws IOException { - if (finished) { - return; - } - assert flatFieldVectorsWriter.isFinished(); - finished = true; - } - - @Override - public 
boolean isFinished() { - return finished && flatFieldVectorsWriter.isFinished(); - } - - @Override - public void addValue(int docID, float[] vectorValue) throws IOException { - flatFieldVectorsWriter.addValue(docID, vectorValue); - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - float dp = VectorUtil.dotProduct(vectorValue, vectorValue); - float divisor = (float) Math.sqrt(dp); - magnitudes.add(divisor); - for (int i = 0; i < vectorValue.length; i++) { - dimensionSums[i] += (vectorValue[i] / divisor); - } - } else { - for (int i = 0; i < vectorValue.length; i++) { - dimensionSums[i] += vectorValue[i]; - } - } - } - - @Override - public float[] copyValue(float[] vectorValue) { - throw new UnsupportedOperationException(); - } - - @Override - public long ramBytesUsed() { - long size = SHALLOW_SIZE; - size += flatFieldVectorsWriter.ramBytesUsed(); - size += magnitudes.ramBytesUsed(); - return size; - } - } -} From ec74d8724c81ecffe7ea31b522ba89387c20f843 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 8 Aug 2025 15:01:31 +0000 Subject: [PATCH 10/17] [CI] Auto commit changes from spotless --- .../index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java index 1b134e9b45242..7097d71ecc2fe 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java @@ -29,9 +29,9 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.MergeInfo; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; import 
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; From 2f4c21f183b93e47ada189ffa78dea7a2036a936 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 8 Aug 2025 16:38:20 +0100 Subject: [PATCH 11/17] Remove byte vector support --- .../es92/ES92BFloat16FlatVectorsReader.java | 40 ++------ .../es92/ES92BFloat16FlatVectorsWriter.java | 97 +++++-------------- ...ryQuantizedBFloat16VectorsFormatTests.java | 63 ------------ 3 files changed, 33 insertions(+), 167 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java index dac2113ccd9ef..99d7302e4bc56 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ -22,7 +22,6 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; -import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; @@ -221,17 +220,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); - return OffHeapByteVectorValues.load( - fieldEntry.similarityFunction, - vectorScorer, - fieldEntry.ordToDoc, - fieldEntry.vectorEncoding, - fieldEntry.dimension, - fieldEntry.vectorDataOffset, - 
fieldEntry.vectorDataLength, - vectorData - ); + throw new IllegalStateException(field + " only supports float vectors"); } @Override @@ -256,21 +245,7 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th @Override public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { - final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); - return vectorScorer.getRandomVectorScorer( - fieldEntry.similarityFunction, - OffHeapByteVectorValues.load( - fieldEntry.similarityFunction, - vectorScorer, - fieldEntry.ordToDoc, - fieldEntry.vectorEncoding, - fieldEntry.dimension, - fieldEntry.vectorDataOffset, - fieldEntry.vectorDataLength, - vectorData - ), - target - ); + throw new UnsupportedOperationException(field + " only supports float vectors"); } @Override @@ -297,6 +272,12 @@ private record FieldEntry( ) { FieldEntry { + if (vectorEncoding == VectorEncoding.BYTE) { + throw new IllegalStateException( + "Incorrect vector encoding for field=\"" + info.name + "\"; " + vectorEncoding + " not supported" + ); + } + if (similarityFunction != info.getVectorSimilarityFunction()) { throw new IllegalStateException( "Inconsistent vector similarity function for field=\"" @@ -314,10 +295,7 @@ private record FieldEntry( ); } - int byteSize = switch (info.getVectorEncoding()) { - case BYTE -> Byte.BYTES; - case FLOAT32 -> BFloat16.BYTES; - }; + int byteSize = BFloat16.BYTES; long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); long numBytes = Math.multiplyExact(vectorBytes, size); if (numBytes != vectorDataLength) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java index a3ef930574aa5..b40765b051538 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java @@ -23,10 +23,8 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -157,10 +155,12 @@ public long ramBytesUsed() { private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { // write vector values - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); switch (fieldData.fieldInfo.getVectorEncoding()) { - case BYTE -> writeByteVectors(fieldData); case FLOAT32 -> writeBFloat16Vectors(fieldData); + case BYTE -> throw new IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -175,13 +175,6 @@ private void writeBFloat16Vectors(FieldWriter fieldData) throws IOException { } } - private void writeByteVectors(FieldWriter fieldData) throws IOException { - for (Object v : fieldData.vectors) { - byte[] vector = (byte[]) v; - vectorData.writeBytes(vector, vector.length); - } - } - private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord @@ -190,8 +183,10 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM // write vector values long vectorDataOffset = 
switch (fieldData.fieldInfo.getVectorEncoding()) { - case BYTE -> writeSortedByteVectors(fieldData, ordMap); case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap); + case BYTE -> throw new IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); }; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; @@ -199,7 +194,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM } private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) throws IOException { - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); for (int ordinal : ordMap) { float[] vector = (float[]) fieldData.vectors.get(ordinal); @@ -209,24 +204,15 @@ private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) return vectorDataOffset; } - private long writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); - for (int ordinal : ordMap) { - byte[] vector = (byte[]) fieldData.vectors.get(ordinal); - vectorData.writeBytes(vector, vector.length); - } - return vectorDataOffset; - } - @Override public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { // Since we know we will not be searching for additional indexing, we can just write the // the vectors directly to the new segment. 
- long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); // No need to use temporary file as we don't have to re-open for reading DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> writeByteVectorData(vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); case FLOAT32 -> writeVectorData(vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); }; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); @@ -234,15 +220,15 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE @Override public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); IndexOutput tempVectorData = segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context); IndexInput vectorDataInput = null; boolean success = false; try { // write the vector data to a temporary file DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> writeByteVectorData(tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); case FLOAT32 -> writeVectorData(tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new UnsupportedOperationException("ES92BFloat16FlatVectorsWriter only supports float vectors"); }; CodecUtil.writeFooter(tempVectorData); IOUtils.close(tempVectorData); @@ -261,30 +247,17 @@ public 
CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldI writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); success = true; final IndexInput finalVectorDataInput = vectorDataInput; - final RandomVectorScorerSupplier randomVectorScorerSupplier = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> vectorsScorer.getRandomVectorScorerSupplier( - fieldInfo.getVectorSimilarityFunction(), - new OffHeapByteVectorValues.DenseOffHeapVectorValues( - fieldInfo.getVectorDimension(), - docsWithField.cardinality(), - finalVectorDataInput, - fieldInfo.getVectorDimension() * Byte.BYTES, - vectorsScorer, - fieldInfo.getVectorSimilarityFunction() - ) - ); - case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier( - fieldInfo.getVectorSimilarityFunction(), - new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - fieldInfo.getVectorDimension(), - docsWithField.cardinality(), - finalVectorDataInput, - fieldInfo.getVectorDimension() * Float.BYTES, - vectorsScorer, - fieldInfo.getVectorSimilarityFunction() - ) - ); - }; + final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * BFloat16.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); return new FlatCloseableRandomVectorScorerSupplier(() -> { IOUtils.close(finalVectorDataInput); segmentWriteState.directory.deleteFile(tempVectorData.getName()); @@ -316,23 +289,6 @@ private void writeMeta(FieldInfo field, int maxDoc, long vectorDataOffset, long OrdToDocDISIReaderConfiguration.writeStoredMeta(DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); } - /** - * Writes the byte vector values to the output and returns a set of documents that 
contains - * vectors. - */ - private static DocsWithFieldSet writeByteVectorData(IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { - DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); - for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { - // write vector - byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); - assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; - output.writeBytes(binaryValue, binaryValue.length); - docsWithField.add(docV); - } - return docsWithField; - } - /** * Writes the vector values to the output and returns a set of documents that contains vectors. */ @@ -368,18 +324,13 @@ private abstract static class FieldWriter extends FlatFieldVectorsWriter { static FieldWriter create(FieldInfo fieldInfo) { int dim = fieldInfo.getVectorDimension(); return switch (fieldInfo.getVectorEncoding()) { - case BYTE -> new ES92BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { - @Override - public byte[] copyValue(byte[] value) { - return ArrayUtil.copyOfSubArray(value, 0, dim); - } - }; case FLOAT32 -> new ES92BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { @Override public float[] copyValue(float[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); } }; + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); }; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java index 4bb9dcae1c116..fa3aa9c25f72b 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java @@ -29,11 +29,9 @@ import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.SoftDeletesRetentionMergePolicy; import org.apache.lucene.index.Term; @@ -63,12 +61,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.codec.vectors.BQVectorUtils; -import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; -import org.elasticsearch.index.codec.vectors.es818.BinarizedByteVectorValues; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardPath; import org.elasticsearch.index.store.FsDirectoryFactory; @@ -85,7 +78,6 @@ import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -224,61 +216,6 @@ public void testVectorValuesReportCorrectDocs() throws Exception {} @Override public void testSparseVectors() throws Exception {} - public void testQuantizedVectorsWriteAndRead() throws IOException { - String fieldName = "field"; - int 
numVectors = random().nextInt(99, 500); - int dims = random().nextInt(4, 65); - - float[] vector = randomVector(dims); - VectorSimilarityFunction similarityFunction = randomSimilarity(); - KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); - try (Directory dir = newDirectory()) { - try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { - for (int i = 0; i < numVectors; i++) { - Document doc = new Document(); - knnField.setVectorValue(randomVector(dims)); - doc.add(knnField); - w.addDocument(doc); - if (i % 101 == 0) { - w.commit(); - } - } - w.commit(); - w.forceMerge(1); - - try (IndexReader reader = DirectoryReader.open(w)) { - LeafReader r = getOnlyLeafReader(reader); - FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(vectorValues.size(), numVectors); - BinarizedByteVectorValues qvectorValues = ((ES818BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) - .getQuantizedVectorValues(); - float[] centroid = qvectorValues.getCentroid(); - assertEquals(centroid.length, dims); - - OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); - int[] quantizedVector = new int[dims]; - byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8]; - if (similarityFunction == VectorSimilarityFunction.COSINE) { - vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); - } - KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); - - while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { - OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize( - vectorValues.vectorValue(docIndexIterator.index()), - quantizedVector, - (byte) 1, - centroid - ); - BQVectorUtils.packAsBinary(quantizedVector, expectedVector); - assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); - assertEquals(corrections, 
qvectorValues.getCorrectiveTerms(docIndexIterator.index())); - } - } - } - } - } - public void testSimpleOffHeapSize() throws IOException { try (Directory dir = newDirectory()) { testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); From 3278721f2065fb8670a3200b69137cc7a4f76565 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 27 Aug 2025 09:50:23 +0100 Subject: [PATCH 12/17] Update for API changes --- .../es92/ES92BFloat16FlatVectorsReader.java | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java index 99d7302e4bc56..4b863d1d958c0 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ -58,21 +58,24 @@ public final class ES92BFloat16FlatVectorsReader extends FlatVectorsReader { private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; private final FieldInfos fieldInfos; + private final IOContext dataContext; public ES92BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) throws IOException { super(scorer); int versionMeta = readMetadata(state); this.fieldInfos = state.fieldInfos; boolean success = false; + // Flat formats are used to randomly access vectors from their node ID that is stored + // in the HNSW graph. + dataContext = + state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM); try { vectorData = openDataInput( state, versionMeta, ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, ES92BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, - // Flat formats are used to randomly access vectors from their node ID that is stored - // in the HNSW graph. 
- state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + dataContext ); success = true; } finally { @@ -173,14 +176,10 @@ public void checkIntegrity() throws IOException { } @Override - public FlatVectorsReader getMergeInstance() { - try { - // Update the read advice since vectors are guaranteed to be accessed sequentially for merge - this.vectorData.updateReadAdvice(ReadAdvice.SEQUENTIAL); - return this; - } catch (IOException exception) { - throw new UncheckedIOException(exception); - } + public FlatVectorsReader getMergeInstance() throws IOException { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + vectorData.updateIOContext(dataContext.withHints(DataAccessHint.SEQUENTIAL)); + return this; } private FieldEntry getFieldEntryOrThrow(String field) { @@ -252,7 +251,7 @@ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) thr public void finishMerge() throws IOException { // This makes sure that the access pattern hint is reverted back since HNSW implementation // needs it - this.vectorData.updateReadAdvice(ReadAdvice.RANDOM); + vectorData.updateIOContext(dataContext); } @Override From 7f5ce363029429487856b2ef715ddf8d0f86f97b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 27 Aug 2025 08:57:45 +0000 Subject: [PATCH 13/17] [CI] Auto commit changes from spotless --- .../codec/vectors/es92/ES92BFloat16FlatVectorsReader.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java index 4b863d1d958c0..0dd083131b462 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ 
-39,13 +39,11 @@ import org.apache.lucene.store.FileTypeHint; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.ReadAdvice; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.elasticsearch.core.IOUtils; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.Map; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; @@ -67,8 +65,7 @@ public ES92BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer s boolean success = false; // Flat formats are used to randomly access vectors from their node ID that is stored // in the HNSW graph. - dataContext = - state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM); + dataContext = state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM); try { vectorData = openDataInput( state, From 6a8ccffa8b7871645de604e604db0952f1854866 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 27 Aug 2025 10:54:01 +0100 Subject: [PATCH 14/17] Update tests --- ...ryQuantizedBFloat16VectorsFormatTests.java | 38 +++++++++++++++ ...ryQuantizedBFloat16VectorsFormatTests.java | 47 ++++++++++++++++++- 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java index fa3aa9c25f72b..5a80207afd104 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java @@ -78,6 +78,7 @@ import static java.lang.String.format; import static 
org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -193,6 +194,43 @@ public KnnVectorsFormat knnVectorsFormat() { assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } + @Override + public void testRandomBytes() throws Exception { + // floats only + var ex = expectThrows(IllegalStateException.class, super::testRandomBytes); + assertThat(ex.getMessage(), equalTo("Incorrect encoding for field field: BYTE")); + } + + @Override + public void testSortedIndexBytes() { + // floats only + } + + @Override + public void testMergingWithDifferentByteKnnFields() { + // floats only + } + + @Override + public void testEmptyByteVectorData() { + // floats only + } + + @Override + public void testByteVectorScorerIteration() { + // floats only + } + + @Override + public void testMismatchedFields() { + // floats only + } + + @Override + public void testRandomExceptions() { + // this sometimes uses bytes - ignore + } + @Override public void testRandomWithUpdatesAndGraph() { // graph not supported diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java index 73d732e5fcc3e..d9457c943a441 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.misc.store.DirectIODirectory; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.TopDocs; import 
org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; @@ -68,6 +69,7 @@ import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -120,7 +122,13 @@ public void testSingleVectorCase() throws Exception { } float[] randomVector = randomVector(vector.length); float trueScore = similarityFunction.compare(vector, randomVector); - TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + TopDocs td = r.searchNearestVectors( + "f", + randomVector, + 1, + AcceptDocs.fromLiveDocs(r.getLiveDocs(), r.maxDoc()), + Integer.MAX_VALUE + ); assertEquals(1, td.totalHits.value()); assertTrue(td.scoreDocs[0].score >= 0); // When it's the only vector in a segment, the score should be very close to the true score @@ -151,6 +159,43 @@ public void testVectorSimilarityFuncs() { assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); } + @Override + public void testRandomBytes() throws Exception { + // floats only + var ex = expectThrows(IllegalStateException.class, super::testRandomBytes); + assertThat(ex.getMessage(), equalTo("Incorrect encoding for field field: BYTE")); + } + + @Override + public void testSortedIndexBytes() { + // floats only + } + + @Override + public void testMergingWithDifferentByteKnnFields() { + // floats only + } + + @Override + public void testEmptyByteVectorData() { + // floats only + } + + @Override + public void testByteVectorScorerIteration() { + // floats only + } + + @Override + public void testMismatchedFields() { + // floats only + } + + @Override + public void testRandomExceptions() { + // this sometimes uses bytes - ignore + } + // bfloat16 makes the results of these tests slightly out of bounds @Override public void 
testWriterRamEstimate() throws Exception {} From 2da69a8da6571ef2dd246682997a778de0b93b6c Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 28 Aug 2025 10:25:12 +0100 Subject: [PATCH 15/17] Use abstract classes --- .../es92/ES92BFloat16FlatVectorsFormat.java | 17 +++-- ...2BinaryQuantizedBFloat16VectorsFormat.java | 28 +++----- ...wBinaryQuantizedBFloat16VectorsFormat.java | 70 +++---------------- 3 files changed, 24 insertions(+), 91 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java index 7097d71ecc2fe..3ea0c56670172 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java @@ -19,7 +19,6 @@ */ package org.elasticsearch.index.codec.vectors.es92; -import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; @@ -29,15 +28,15 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.MergeInfo; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; import java.util.Set; -public final class ES92BFloat16FlatVectorsFormat extends FlatVectorsFormat { +public final class ES92BFloat16FlatVectorsFormat extends AbstractFlatVectorsFormat { static final String NAME = "ES92BFloat16FlatVectorsFormat"; static 
final String META_CODEC_NAME = "ES92BFloat16FlatVectorsFormatMeta"; @@ -56,13 +55,18 @@ public ES92BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { this.vectorsScorer = vectorsScorer; } + @Override + protected FlatVectorsScorer flatVectorsScorer() { + return vectorsScorer; + } + @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new ES92BFloat16FlatVectorsWriter(state, vectorsScorer); } static boolean shouldUseDirectIO(SegmentReadState state) { - return ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO && FsDirectoryFactory.isHybridFs(state.directory); + return USE_DIRECT_IO && FsDirectoryFactory.isHybridFs(state.directory); } @Override @@ -87,11 +91,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException } } - @Override - public String toString() { - return "ES92BFloat16FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')'; - } - static class DirectIOContext implements IOContext { final Set hints; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java index c029cecb2200b..447b47ea3dbe4 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java @@ -22,9 +22,11 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import 
org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; @@ -87,19 +89,10 @@ *
  • The sparse vector information, if required, mapping vector ordinal to doc ID * */ -public class ES92BinaryQuantizedBFloat16VectorsFormat extends FlatVectorsFormat { +public class ES92BinaryQuantizedBFloat16VectorsFormat extends AbstractFlatVectorsFormat { - public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; public static final String NAME = "ES92BinaryQuantizedBFloat16VectorsFormat"; - static final int VERSION_START = 0; - static final int VERSION_CURRENT = VERSION_START; - static final String META_CODEC_NAME = "ES92BinaryQuantizedBFloat16VectorsFormatMeta"; - static final String VECTOR_DATA_CODEC_NAME = "ES92BinaryQuantizedBFloat16VectorsFormatData"; - static final String META_EXTENSION = "vemb"; - static final String VECTOR_DATA_EXTENSION = "veb"; - static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; - private static final FlatVectorsFormat rawVectorFormat = new ES92BFloat16FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); @@ -113,6 +106,11 @@ public ES92BinaryQuantizedBFloat16VectorsFormat() { super(NAME); } + @Override + protected FlatVectorsScorer flatVectorsScorer() { + return scorer; + } + @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); @@ -122,14 +120,4 @@ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); } - - @Override - public int getMaxDimensions(String fieldName) { - return MAX_DIMS_COUNT; - } - - @Override - public String toString() { - return "ES92BinaryQuantizedBFloat16VectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; - } } diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java index 38b54d1520f9f..2d850e0c04cae 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java @@ -19,17 +19,14 @@ */ package org.elasticsearch.index.codec.vectors.es92; -import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.search.TaskExecutor; -import org.apache.lucene.util.hnsw.HnswGraph; +import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -37,36 +34,17 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; /** * Copied from Lucene, replace with Lucene's implementation sometime 
after Lucene 10 */ -public class ES92HnswBinaryQuantizedBFloat16VectorsFormat extends KnnVectorsFormat { +public class ES92HnswBinaryQuantizedBFloat16VectorsFormat extends AbstractHnswVectorsFormat { public static final String NAME = "ES92HnswBinaryQuantizedBFloat16VectorsFormat"; - /** - * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to - * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. - */ - private final int maxConn; - - /** - * The number of candidate neighbors to track while searching the graph for each newly inserted - * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} - * for details. - */ - private final int beamWidth; - /** The format for storing, reading, merging vectors on disk */ private static final FlatVectorsFormat flatVectorsFormat = new ES92BinaryQuantizedBFloat16VectorsFormat(); - private final int numMergeWorkers; - private final TaskExecutor mergeExec; - /** Constructs a format using default graph construction parameters */ public ES92HnswBinaryQuantizedBFloat16VectorsFormat() { this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); @@ -93,28 +71,12 @@ public ES92HnswBinaryQuantizedBFloat16VectorsFormat(int maxConn, int beamWidth) * generated by this format to do the merge */ public ES92HnswBinaryQuantizedBFloat16VectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { - super(NAME); - if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { - throw new IllegalArgumentException( - "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn - ); - } - if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { - throw new IllegalArgumentException( - "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth - ); - } - this.maxConn = maxConn; - this.beamWidth = beamWidth; - if 
(numMergeWorkers == 1 && mergeExec != null) { - throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); - } - this.numMergeWorkers = numMergeWorkers; - if (mergeExec != null) { - this.mergeExec = new TaskExecutor(mergeExec); - } else { - this.mergeExec = null; - } + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + } + + @Override + protected FlatVectorsFormat flatVectorsFormat() { + return flatVectorsFormat; } @Override @@ -126,20 +88,4 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); } - - @Override - public int getMaxDimensions(String fieldName) { - return MAX_DIMS_COUNT; - } - - @Override - public String toString() { - return "ES92HnswBinaryQuantizedBFloat16VectorsFormat(name=ES92HnswBinaryQuantizedBFloat16VectorsFormat, maxConn=" - + maxConn - + ", beamWidth=" - + beamWidth - + ", flatVectorFormat=" - + flatVectorsFormat - + ")"; - } } From 007eb2cfd3fe3407286548a5efd97458b9b79bfb Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 28 Aug 2025 09:34:31 +0000 Subject: [PATCH 16/17] [CI] Auto commit changes from spotless --- .../vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java index 447b47ea3dbe4..322a943211320 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java @@ -34,8 +34,6 @@ import java.io.IOException; -import static 
org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; - /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 * Codec for encoding/decoding binary quantized vectors The binary quantization format used here From 6ebcea3bc16f7c1ea2d65c041693c8230439f5fc Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 25 Sep 2025 12:18:16 +0100 Subject: [PATCH 17/17] Fix compile --- .../codec/vectors/{es92 => }/BFloat16.java | 14 ++++++++-- .../es92/ES92BFloat16FlatVectorsReader.java | 1 + .../es92/ES92BFloat16FlatVectorsWriter.java | 1 + .../es92/OffHeapBFloat16VectorValues.java | 1 + .../index/mapper/BlockDocValuesReader.java | 28 ++++++++++++++++++- .../vectors/DenseVectorFieldMapper.java | 7 +++-- ...ryQuantizedBFloat16VectorsFormatTests.java | 1 + ...ryQuantizedBFloat16VectorsFormatTests.java | 1 + 8 files changed, 48 insertions(+), 6 deletions(-) rename server/src/main/java/org/elasticsearch/index/codec/vectors/{es92 => }/BFloat16.java (80%) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java similarity index 80% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java rename to server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index 639ea6e8eed5a..f178e1e61ba5d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -7,14 +7,14 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.index.codec.vectors.es92; +package org.elasticsearch.index.codec.vectors; import org.apache.lucene.util.BitUtil; import java.nio.ByteOrder; import java.nio.ShortBuffer; -class BFloat16 { +public class BFloat16 { public static final int BYTES = Short.BYTES; @@ -47,4 +47,14 @@ public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2)); } } + + public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) { + assert floats.length == bFloats.remaining(); + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat(bFloats.get()); + } + } + + private BFloat16() {} } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java index 0dd083131b462..469b4f7dc98a1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ -42,6 +42,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.IOException; import java.util.Map; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java index b40765b051538..d3e4ea56ee5c4 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java @@ -46,6 +46,7 @@ import 
org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.Closeable; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java index f793c94d47a54..df789f2cc41a2 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java @@ -32,6 +32,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 457c90383b5d2..70deb110b3a87 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.ByteArrayStreamInput; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.BlockLoader.BlockFactory; import org.elasticsearch.index.mapper.BlockLoader.BooleanBuilder; import org.elasticsearch.index.mapper.BlockLoader.Builder; @@ -36,6 +37,9 @@ import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; import static 
org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; @@ -536,7 +540,8 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { switch (fieldType.getElementType()) { - case FLOAT -> { + case FLOAT, BFLOAT16 -> { + // BFloat16 is handled by the implementation of FloatVectorValues FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { if (fieldType.isNormalized()) { @@ -1052,6 +1057,7 @@ public AllReader reader(LeafReaderContext context) throws IOException { } return switch (elementType) { case FLOAT -> new FloatDenseVectorFromBinary(docValues, dims, indexVersion); + case BFLOAT16 -> new BFloat16DenseVectorFromBinary(docValues, dims, indexVersion); case BYTE -> new ByteDenseVectorFromBinary(docValues, dims, indexVersion); case BIT -> new BitDenseVectorFromBinary(docValues, dims, indexVersion); }; @@ -1135,6 +1141,26 @@ public String toString() { } } + private static class BFloat16DenseVectorFromBinary extends FloatDenseVectorFromBinary { + BFloat16DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion); + } + + @Override + protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); + ShortBuffer sb = ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + BFloat16.bFloat16ToFloat(sb, scratch); + } + + @Override + public String toString() { + return "BFloat16DenseVectorFromBinary.Bytes"; + } + } + private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary { ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { this(docValues, dims, indexVersion, dims); diff --git 
a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index eea7c6dea3635..6e8472d89f327 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -49,6 +49,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; @@ -1065,17 +1066,17 @@ public ElementType elementType() { @Override public void writeValue(ByteBuffer byteBuffer, float value) { - byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); + byteBuffer.putShort(BFloat16.floatToBFloat16(value)); } @Override public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { - b.value(Float.intBitsToFloat(byteBuffer.getShort() << 16)); + b.value(BFloat16.bFloat16ToFloat(byteBuffer.getShort())); } @Override public int getNumBytes(int dimensions) { - return dimensions * Short.BYTES; + return dimensions * BFloat16.BYTES; } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java index 5a80207afd104..1c88d5628f7f4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java @@ -61,6 
+61,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardPath; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java index d9457c943a441..56afb6d68fc84 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -52,6 +52,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.shard.ShardId;