From c72eda64e8970cea2540df658c27c364f684d045 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 10:33:52 +0100 Subject: [PATCH 01/46] Add BFloat16 raw vector format to bbq_hnsw and bbq_disk --- .../elasticsearch/test/knn/CmdLineArgs.java | 11 + .../index/codec/vectors/BFloat16.java | 60 +++ .../DirectIOCapableFlatVectorsFormat.java | 65 ++- .../diskbbq/ES920DiskBBQVectorsFormat.java | 18 +- ...ectIOCapableLucene99FlatVectorsFormat.java | 76 +-- .../es93/ES93BFloat16FlatVectorsFormat.java | 64 +++ .../es93/ES93BFloat16FlatVectorsReader.java | 325 +++++++++++++ .../es93/ES93BFloat16FlatVectorsWriter.java | 434 ++++++++++++++++++ .../ES93BinaryQuantizedVectorsFormat.java | 22 +- .../ES93HnswBinaryQuantizedVectorsFormat.java | 9 +- .../es93/OffHeapBFloat16VectorValues.java | 312 +++++++++++++ .../vectors/DenseVectorFieldMapper.java | 3 +- .../BFloat16RankVectorsDocValuesField.java | 157 +++++++ ...S920DiskBBQBFloat16VectorsFormatTests.java | 96 ++++ .../ES920DiskBBQVectorsFormatTests.java | 13 +- ...ryQuantizedBFloat16VectorsFormatTests.java | 97 ++++ ...ES93BinaryQuantizedVectorsFormatTests.java | 10 +- ...ryQuantizedBFloat16VectorsFormatTests.java | 110 +++++ ...HnswBinaryQuantizedVectorsFormatTests.java | 26 +- 19 files changed, 1802 insertions(+), 106 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java create mode 
100644 server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 773ad3c8da682..27272418b29f0 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -51,6 +51,7 @@ record CmdLineArgs( float filterSelectivity, long seed, VectorSimilarityFunction vectorSpace, + int rawVectorSize, int quantizeBits, VectorEncoding vectorEncoding, int dimensions, @@ -80,6 +81,7 @@ record CmdLineArgs( static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); + static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size"); static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); @@ -123,6 +125,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); + PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD); PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); PARSER.declareString(Builder::setVectorEncoding, 
VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); @@ -161,6 +164,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(REINDEX_FIELD.getPreferredName(), reindex); builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); + builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize); builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); @@ -196,6 +200,7 @@ static class Builder { private boolean reindex = false; private boolean forceMerge = false; private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; + private int rawVectorSize = 32; private int quantizeBits = 8; private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; private int dimensions; @@ -305,6 +310,11 @@ public Builder setVectorSpace(String vectorSpace) { return this; } + public Builder setRawVectorSize(int rawVectorSize) { + this.rawVectorSize = rawVectorSize; + return this; + } + public Builder setQuantizeBits(int quantizeBits) { this.quantizeBits = quantizeBits; return this; @@ -380,6 +390,7 @@ public CmdLineArgs build() { filterSelectivity, seed, vectorSpace, + rawVectorSize, quantizeBits, vectorEncoding, dimensions, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java new file mode 100644 index 0000000000000..f178e1e61ba5d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.util.BitUtil; + +import java.nio.ByteOrder; +import java.nio.ShortBuffer; + +public class BFloat16 { + + public static final int BYTES = Short.BYTES; + + public static short floatToBFloat16(float f) { + // this rounds towards 0 + // zero - zero exp, zero fraction + // denormal - zero exp, non-zero fraction + // infinity - all-1 exp, zero fraction + // NaN - all-1 exp, non-zero fraction + // the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into + // infinities + return (short) (Float.floatToIntBits(f) >>> 16); + } + + public static float bFloat16ToFloat(short bf) { + return Float.intBitsToFloat(bf << 16); + } + + public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { + assert bFloats.remaining() == floats.length; + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (float v : floats) { + bFloats.put(floatToBFloat16(v)); + } + } + + public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { + assert floats.length * 2 == bfBytes.length; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2)); + } + } + + public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) { + assert floats.length == bFloats.remaining(); + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat(bFloats.get()); + } + } + + private BFloat16() {} +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java index 0d67281bf5606..0c92b4a8fd226 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java @@ -11,18 +11,81 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.FlushInfo; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.MergeInfo; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; +import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; +import java.util.Set; public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat { protected DirectIOCapableFlatVectorsFormat(String name) { super(name); } + protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException; + + static boolean canUseDirectIO(SegmentReadState state) { + return FsDirectoryFactory.isHybridFs(state.directory); + } + @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { return fieldsReader(state, false); } - public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException; + public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException { + if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) { + // only override the context for the random-access use case + SegmentReadState directIOState = new SegmentReadState( + state.directory, + state.segmentInfo, + state.fieldInfos, + new DirectIOContext(state.context.hints()), + state.segmentSuffix + ); + // Use mmap for merges and direct I/O for searches. 
+ return new MergeReaderWrapper(createReader(directIOState), createReader(state)); + } else { + return createReader(state); + } + } + + static class DirectIOContext implements IOContext { + + final Set hints; + + DirectIOContext(Set hints) { + // always add DirectIOHint to the hints given + this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE)); + } + + @Override + public Context context() { + return Context.DEFAULT; + } + + @Override + public MergeInfo mergeInfo() { + return null; + } + + @Override + public FlushInfo flushInfo() { + return null; + } + + @Override + public Set hints() { + return hints; + } + + @Override + public IOContext withHints(FileOpenHint... hints) { + return new DirectIOContext(Set.of(hints)); + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index 64796d10662e4..cee32ac4ef470 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -18,6 +18,7 @@ import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat; import java.io.IOException; import java.util.Map; @@ -58,12 +59,17 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { public static final int VERSION_DIRECT_IO = 1; public static final int VERSION_CURRENT = VERSION_DIRECT_IO; - private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new 
DirectIOCapableLucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); private static final Map supportedFormats = Map.of( - rawVectorFormat.getName(), - rawVectorFormat + float32VectorFormat.getName(), + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat ); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. @@ -79,12 +85,13 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final int vectorPerCluster; private final int centroidsPerParentCluster; private final boolean useDirectIO; + private final DirectIOCapableFlatVectorsFormat rawVectorFormat; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false); + this(vectorPerCluster, centroidsPerParentCluster, false, false); } - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) { + public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -109,6 +116,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; this.useDirectIO = useDirectIO; + this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index a15cbba346353..42392ce406629 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -15,17 +15,9 @@ import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.store.FlushInfo; -import org.apache.lucene.store.IOContext; -import org.apache.lucene.store.MergeInfo; -import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; -import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; -import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; -import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; -import java.util.Set; public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { @@ -45,72 +37,12 @@ protected FlatVectorsScorer flatVectorsScorer() { } @Override - public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99FlatVectorsWriter(state, vectorsScorer); - } - - static boolean canUseDirectIO(SegmentReadState state) { - return FsDirectoryFactory.isHybridFs(state.directory); + protected FlatVectorsReader createReader(SegmentReadState state) throws IOException { + return new Lucene99FlatVectorsReader(state, vectorsScorer); } @Override - public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return fieldsReader(state, false); - } - - @Override - public FlatVectorsReader 
fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException { - if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) { - // only override the context for the random-access use case - SegmentReadState directIOState = new SegmentReadState( - state.directory, - state.segmentInfo, - state.fieldInfos, - new DirectIOContext(state.context.hints()), - state.segmentSuffix - ); - // Use mmap for merges and direct I/O for searches. - return new MergeReaderWrapper( - new Lucene99FlatVectorsReader(directIOState, vectorsScorer), - new Lucene99FlatVectorsReader(state, vectorsScorer) - ); - } else { - return new Lucene99FlatVectorsReader(state, vectorsScorer); - } - } - - static class DirectIOContext implements IOContext { - - final Set hints; - - DirectIOContext(Set hints) { - // always add DirectIOHint to the hints given - this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE)); - } - - @Override - public Context context() { - return Context.DEFAULT; - } - - @Override - public MergeInfo mergeInfo() { - return null; - } - - @Override - public FlushInfo flushInfo() { - return null; - } - - @Override - public Set hints() { - return hints; - } - - @Override - public IOContext withHints(FileOpenHint... 
hints) { - return new DirectIOContext(Set.of(hints)); - } + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99FlatVectorsWriter(state, vectorsScorer); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java new file mode 100644 index 0000000000000..c6b2f61a366e9 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java @@ -0,0 +1,64 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; + +import java.io.IOException; + +public final class ES93BFloat16FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { + + static final String NAME = "ES93BFloat16FlatVectorsFormat"; + static final String META_CODEC_NAME = "ES93BFloat16FlatVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES93BFloat16FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; + + public ES93BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + super(NAME); + this.vectorsScorer = vectorsScorer; + } + + @Override + protected FlatVectorsScorer flatVectorsScorer() { + return vectorsScorer; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES93BFloat16FlatVectorsWriter(state, vectorsScorer); + } + + @Override + protected FlatVectorsReader createReader(SegmentReadState state) throws IOException { + return new ES93BFloat16FlatVectorsReader(state, vectorsScorer); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java new file mode 100644 index 0000000000000..c71470d6be15e --- /dev/null +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsReader.java @@ -0,0 +1,325 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import 
org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.IOException; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +public final class ES93BFloat16FlatVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES93BFloat16FlatVectorsReader.class); + + private final IntObjectHashMap fields = new IntObjectHashMap<>(); + private final IndexInput vectorData; + private final FieldInfos fieldInfos; + private final IOContext dataContext; + + public ES93BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; + boolean success = false; + // Flat formats are used to randomly access vectors from their node ID that is stored + // in the HNSW graph. 
+ dataContext = state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM); + try { + vectorData = openDataInput( + state, + versionMeta, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + dataContext + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES93BFloat16FlatVectorsFormat.META_EXTENSION + ); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES93BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES93BFloat16FlatVectorsFormat.VERSION_START, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + ES93BFloat16FlatVectorsFormat.VERSION_START, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + 
versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = FieldEntry.create(meta, info); + fields.put(info.number, fieldEntry); + } + } + + @Override + public long ramBytesUsed() { + return ES93BFloat16FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + final FieldEntry entry = getFieldEntryOrThrow(fieldInfo.name); + return Map.of(ES93BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength()); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FlatVectorsReader getMergeInstance() throws IOException { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + vectorData.updateIOContext(dataContext.withHints(DataAccessHint.SEQUENTIAL)); + return this; + } + + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + return entry; + } + + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldEntry fieldEntry = getFieldEntryOrThrow(field); + if (fieldEntry.vectorEncoding != expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + 
fieldEntry.vectorEncoding + " expected: " + expectedEncoding + ); + } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + throw new IllegalStateException(field + " only supports float vectors"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + throw new UnsupportedOperationException(field + " only supports float vectors"); + } + + @Override + public void finishMerge() throws IOException { + // This makes sure that the access pattern hint is reverted back since HNSW implementation + // needs it + vectorData.updateIOContext(dataContext); + } + + @Override + public void close() throws IOException { + IOUtils.close(vectorData); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorDataOffset, + long vectorDataLength, + int dimension, + int size, + OrdToDocDISIReaderConfiguration 
ordToDoc, + FieldInfo info + ) { + + FieldEntry { + if (vectorEncoding == VectorEncoding.BYTE) { + throw new IllegalStateException( + "Incorrect vector encoding for field=\"" + info.name + "\"; " + vectorEncoding + " not supported" + ); + } + + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + int infoVectorDimension = info.getVectorDimension(); + if (infoVectorDimension != dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + infoVectorDimension + " != " + dimension + ); + } + + int byteSize = BFloat16.BYTES; + long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); + long numBytes = Math.multiplyExact(vectorBytes, size); + if (numBytes != vectorDataLength) { + throw new IllegalStateException( + "Vector data length " + + vectorDataLength + + " not matching size=" + + size + + " * dim=" + + dimension + + " * byteSize=" + + byteSize + + " = " + + numBytes + ); + } + } + + static FieldEntry create(IndexInput input, FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final var vectorDataOffset = input.readVLong(); + final var vectorDataLength = input.readVLong(); + final var dimension = input.readVInt(); + final var size = input.readInt(); + final var ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry(similarityFunction, vectorEncoding, vectorDataOffset, vectorDataLength, dimension, size, ordToDoc, info); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java new file mode 100644 index 0000000000000..3c143d94fd6b5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java @@ -0,0 +1,434 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +public final class ES93BFloat16FlatVectorsWriter extends FlatVectorsWriter { + + private static final long 
SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ES93BFloat16FlatVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final IndexOutput meta, vectorData; + + private final List> fields = new ArrayList<>(); + private boolean finished; + + public ES93BFloat16FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES93BFloat16FlatVectorsFormat.META_EXTENSION + ); + + String vectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION + ); + + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES93BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + vectorData, + ES93BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES93BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FieldWriter newField = FieldWriter.create(fieldInfo); + fields.add(newField); + return newField; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + for (FieldWriter field : fields) { + if (sortMap == null) { + writeField(field, maxDoc); + } else { + writeSortingField(field, maxDoc, sortMap); + } + field.finish(); + } + } + + @Override + public void finish() throws IOException { + 
if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + switch (fieldData.fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeBFloat16Vectors(fieldData); + case BYTE -> throw new IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + } + + private void writeBFloat16Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + BFloat16.floatToBFloat16((float[]) v, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { + final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = switch (fieldData.fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap); + case BYTE -> throw new 
IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + } + + private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + BFloat16.floatToBFloat16(vector, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + return vectorDataOffset; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. 
+ long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + // No need to use temporary file as we don't have to re-open for reading + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeVectorData(vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + IndexOutput tempVectorData = segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context); + IndexInput vectorDataInput = null; + boolean success = false; + try { + // write the vector data to a temporary file + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeVectorData(tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new UnsupportedOperationException("ES92BFloat16FlatVectorsWriter only supports float vectors"); + }; + CodecUtil.writeFooter(tempVectorData); + IOUtils.close(tempVectorData); + + // This temp file will be accessed in a random-access fashion to construct the HNSW graph. + // Note: don't use the context from the state, which is a flush/merge context, not expecting + // to perform random reads. 
+ vectorDataInput = segmentWriteState.directory.openInput( + tempVectorData.getName(), + IOContext.DEFAULT.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + // copy the temporary file vectors to the actual data file + vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(vectorDataInput); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + success = true; + final IndexInput finalVectorDataInput = vectorDataInput; + final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * BFloat16.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + return new FlatCloseableRandomVectorScorerSupplier(() -> { + IOUtils.close(finalVectorDataInput); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + }, docsWithField.cardinality(), randomVectorScorerSupplier); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); + try { + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + } catch (Exception e) { + // ignore + } + } + } + } + + private void writeMeta(FieldInfo field, int maxDoc, long vectorDataOffset, long vectorDataLength, DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + + // write docIDs + int 
count = docsWithField.cardinality(); + meta.writeInt(count); + OrdToDocDISIReaderConfiguration.writeStoredMeta(DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(iter.index()); + BFloat16.floatToBFloat16(value, buffer.asShortBuffer()); + output.writeBytes(buffer.array(), buffer.limit()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData); + } + + private abstract static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final int dim; + private final DocsWithFieldSet docsWithField; + private final List vectors; + private boolean finished; + + private int lastDocID = -1; + + static FieldWriter create(FieldInfo fieldInfo) { + int dim = fieldInfo.getVectorDimension(); + return switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> new ES93BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { + @Override + public float[] copyValue(float[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); + }; + } + + 
FieldWriter(FieldInfo fieldInfo) { + super(); + this.fieldInfo = fieldInfo; + this.dim = fieldInfo.getVectorDimension(); + this.docsWithField = new DocsWithFieldSet(); + vectors = new ArrayList<>(); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)" + ); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (vectors.size() == 0) return size; + + int byteSize = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 + ? BFloat16.BYTES + : fieldInfo.getVectorEncoding().byteSize; + + return size + docsWithField.ramBytesUsed() + (long) vectors.size() * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) vectors.size() * fieldInfo.getVectorDimension() * byteSize; + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + } + + static final class FlatCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + FlatCloseableRandomVectorScorerSupplier(Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors = 
numVectors; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 69f4f96a4e829..2d5592314f198 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. * - * Modifications copyright (C) 2024 Elasticsearch B.V. + * Modifications copyright (C) 2025 Elasticsearch B.V. 
*/ package org.elasticsearch.index.codec.vectors.es93; @@ -91,13 +91,18 @@ public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsForm public static final String NAME = "ES93BinaryQuantizedVectorsFormat"; - private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); private static final Map supportedFormats = Map.of( - rawVectorFormat.getName(), - rawVectorFormat + float32VectorFormat.getName(), + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat ); private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( @@ -105,15 +110,16 @@ public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsForm ); private final boolean useDirectIO; + private final DirectIOCapableFlatVectorsFormat rawFormat; public ES93BinaryQuantizedVectorsFormat() { - super(NAME); - this.useDirectIO = false; + this(false, false); } - public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO) { + public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO, boolean useBFloat16) { super(NAME); this.useDirectIO = useDirectIO; + this.rawFormat = useBFloat16 ? 
bfloat16VectorFormat : float32VectorFormat; } @Override @@ -128,7 +134,7 @@ protected boolean useDirectIOReads() { @Override protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() { - return rawVectorFormat; + return rawFormat; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index b1ade1524e250..579c42edc6288 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. * - * Modifications copyright (C) 2024 Elasticsearch B.V. + * Modifications copyright (C) 2025 Elasticsearch B.V. */ package org.elasticsearch.index.codec.vectors.es93; @@ -51,9 +51,9 @@ public ES93HnswBinaryQuantizedVectorsFormat() { * @param beamWidth the size of the queue maintained during graph construction. 
* @param useDirectIO whether to use direct IO when reading raw vectors */ - public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO) { + public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO, boolean useBFloat16) { super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO, useBFloat16); } /** @@ -71,11 +71,12 @@ public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, boolean useDirectIO, + boolean useBFloat16, int numMergeWorkers, ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO, useBFloat16); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java new file mode 100644 index 0000000000000..2038cb5232666 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -0,0 +1,312 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * {@link FloatVectorValues} view over vectors stored off-heap as bfloat16 components.
 *
 * <p>Each vector is read from {@link #slice} at {@code ord * byteSize} and decoded into a single
 * reusable {@code float[]}; callers must not hold on to the array returned by {@link #vectorValue(int)}
 * across calls. Concrete subclasses cover the dense, sparse and empty layouts.
 */
abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements HasIndexSlice {

    protected final int dimension;
    protected final int size;
    protected final IndexInput slice;
    // stride, in bytes, between the starts of consecutive vectors in the slice
    protected final int byteSize;
    // ordinal whose decoded contents currently sit in `value`; -1 until the first read
    protected int lastOrd = -1;
    // scratch space for the raw on-disk bytes of one vector
    protected final byte[] bfloatBytes;
    // decoded vector, overwritten by every vectorValue() call for a different ordinal
    protected final float[] value;
    protected final VectorSimilarityFunction similarityFunction;
    protected final FlatVectorsScorer flatVectorsScorer;

    /**
     * @param dimension          number of components per vector
     * @param size               number of vectors
     * @param slice              input positioned so that vector {@code ord} starts at {@code ord * byteSize}
     * @param byteSize           stride in bytes between consecutive vectors
     * @param flatVectorsScorer  scorer used by subclasses to build {@link VectorScorer}s
     * @param similarityFunction similarity function passed to the scorer
     */
    OffHeapBFloat16VectorValues(
        int dimension,
        int size,
        IndexInput slice,
        int byteSize,
        FlatVectorsScorer flatVectorsScorer,
        VectorSimilarityFunction similarityFunction
    ) {
        this.dimension = dimension;
        this.size = size;
        this.slice = slice;
        this.byteSize = byteSize;
        this.similarityFunction = similarityFunction;
        this.flatVectorsScorer = flatVectorsScorer;
        // two bytes per component; NOTE(review): load() computes the stride with BFloat16.BYTES — consider using it here too
        bfloatBytes = new byte[dimension * 2];
        value = new float[dimension];
    }

    @Override
    public int dimension() {
        return dimension;
    }

    @Override
    public int size() {
        return size;
    }

    @Override
    public IndexInput getSlice() {
        return slice;
    }

    /**
     * Returns the decoded vector for {@code targetOrd}.
     *
     * <p>The returned array is shared and reused: it is overwritten by the next call with a different
     * ordinal. A repeated request for the most recently read ordinal skips the read entirely.
     */
    @Override
    public float[] vectorValue(int targetOrd) throws IOException {
        if (lastOrd == targetOrd) {
            return value;
        }
        slice.seek((long) targetOrd * byteSize);
        // no readShorts() method
        slice.readBytes(bfloatBytes, 0, bfloatBytes.length);
        BFloat16.bFloat16ToFloat(bfloatBytes, value);
        lastOrd = targetOrd;
        return value;
    }

    /**
     * Builds the appropriate implementation for a field.
     *
     * <p>Returns an empty instance when the ord-to-doc configuration is empty or the encoding is not
     * {@code FLOAT32}; otherwise slices {@code vectorData} at the given offset/length and returns a dense
     * or sparse reader depending on {@code configuration.isDense()}.
     */
    static OffHeapBFloat16VectorValues load(
        VectorSimilarityFunction vectorSimilarityFunction,
        FlatVectorsScorer flatVectorsScorer,
        OrdToDocDISIReaderConfiguration configuration,
        VectorEncoding vectorEncoding,
        int dimension,
        int size,
        long vectorDataOffset,
        long vectorDataLength,
        IndexInput vectorData
    ) throws IOException {
        if (configuration.isEmpty() || vectorEncoding != VectorEncoding.FLOAT32) {
            return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction);
        }
        IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
        int byteSize = dimension * BFloat16.BYTES;
        if (configuration.isDense()) {
            return new DenseOffHeapVectorValues(dimension, size, bytesSlice, byteSize, flatVectorsScorer, vectorSimilarityFunction);
        } else {
            return new SparseOffHeapVectorValues(
                configuration,
                vectorData,
                bytesSlice,
                dimension,
                size,
                byteSize,
                flatVectorsScorer,
                vectorSimilarityFunction
            );
        }
    }

    /**
     * Dense vector values that are stored off-heap. This is the most common case when every doc has a
     * vector.
     */
    static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues {

        DenseOffHeapVectorValues(
            int dimension,
            int size,
            IndexInput slice,
            int byteSize,
            FlatVectorsScorer flatVectorsScorer,
            VectorSimilarityFunction similarityFunction
        ) {
            super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction);
        }

        /** Returns an independent reader over a clone of the slice, with its own decode buffers. */
        @Override
        public DenseOffHeapVectorValues copy() throws IOException {
            return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
        }

        /** Dense layout: ordinals and doc ids coincide. */
        @Override
        public int ordToDoc(int ord) {
            return ord;
        }

        /** Dense layout: the doc-level filter can be used unchanged as an ordinal-level filter. */
        @Override
        public Bits getAcceptOrds(Bits acceptDocs) {
            return acceptDocs;
        }

        @Override
        public DocIndexIterator iterator() {
            return createDenseIterator();
        }

        /** Returns a scorer for {@code query} that works on a private copy of this reader. */
        @Override
        public VectorScorer scorer(float[] query) throws IOException {
            DenseOffHeapVectorValues copy = copy();
            DocIndexIterator iterator = copy.iterator();
            RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
            return new VectorScorer() {
                @Override
                public float score() throws IOException {
                    // doc id doubles as the ordinal here, see ordToDoc above
                    return randomVectorScorer.score(iterator.docID());
                }

                @Override
                public DocIdSetIterator iterator() {
                    return iterator;
                }
            };
        }
    }

    /**
     * Sparse layout: only some documents have a vector, so ordinals are mapped to doc ids through the
     * stored ord-to-doc structures.
     */
    private static class SparseOffHeapVectorValues extends OffHeapBFloat16VectorValues {
        private final DirectMonotonicReader ordToDoc;
        private final IndexedDISI disi;
        // dataIn was used to init a new IndexedDIS for #randomAccess()
        private final IndexInput dataIn;
        private final OrdToDocDISIReaderConfiguration configuration;

        SparseOffHeapVectorValues(
            OrdToDocDISIReaderConfiguration configuration,
            IndexInput dataIn,
            IndexInput slice,
            int dimension,
            int size,
            int byteSize,
            FlatVectorsScorer flatVectorsScorer,
            VectorSimilarityFunction similarityFunction
        ) throws IOException {

            super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction);
            this.configuration = configuration;
            this.dataIn = dataIn;
            this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
            this.disi = configuration.getIndexedDISI(dataIn);
        }

        /** Returns a new reader over a clone of the slice; the ord-to-doc structures are rebuilt from {@code dataIn}. */
        @Override
        public SparseOffHeapVectorValues copy() throws IOException {
            return new SparseOffHeapVectorValues(
                configuration,
                dataIn,
                slice.clone(),
                dimension,
                size,
                byteSize,
                flatVectorsScorer,
                similarityFunction
            );
        }

        @Override
        public int ordToDoc(int ord) {
            return (int) ordToDoc.get(ord);
        }

        /** Wraps a doc-level filter so it can be queried by ordinal; a {@code null} filter stays {@code null}. */
        @Override
        public Bits getAcceptOrds(Bits acceptDocs) {
            if (acceptDocs == null) {
                return null;
            }
            return new Bits() {
                @Override
                public boolean get(int index) {
                    return acceptDocs.get(ordToDoc(index));
                }

                @Override
                public int length() {
                    return size;
                }
            };
        }

        @Override
        public DocIndexIterator iterator() {
            return IndexedDISI.asDocIndexIterator(disi);
        }

        /** Returns a scorer for {@code query} that works on a private copy of this reader. */
        @Override
        public VectorScorer scorer(float[] query) throws IOException {
            SparseOffHeapVectorValues copy = copy();
            DocIndexIterator iterator = copy.iterator();
            RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
            return new VectorScorer() {
                @Override
                public float score() throws IOException {
                    // sparse layout: score by ordinal (index), not by doc id
                    return randomVectorScorer.score(iterator.index());
                }

                @Override
                public DocIdSetIterator iterator() {
                    return iterator;
                }
            };
        }
    }

    /**
     * Placeholder for a field with no vectors: size is zero, there is no backing slice, and vector
     * access and copying are unsupported.
     */
    private static class EmptyOffHeapVectorValues extends OffHeapBFloat16VectorValues {

        EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) {
            super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction);
        }

        @Override
        public int dimension() {
            return super.dimension();
        }

        @Override
        public int size() {
            return 0;
        }

        @Override
        public EmptyOffHeapVectorValues copy() {
            throw new UnsupportedOperationException();
        }

        @Override
        public float[] vectorValue(int targetOrd) {
            throw new UnsupportedOperationException();
        }

        @Override
        public DocIndexIterator iterator() {
            return createDenseIterator();
        }

        @Override
        public Bits getAcceptOrds(Bits acceptDocs) {
            return null;
        }

        /** There is nothing to score, so no scorer is provided. */
        @Override
        public VectorScorer scorer(float[] query) {
            return null;
        }
    }
}
+ */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; +import java.util.Iterator; +/** Script doc-values field over rank vectors stored as bfloat16; each vector is decoded to a {@code float[]} when it is read. */ +public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField { + /* input holds the packed bfloat16 vectors of a document; magnitudes holds one float per vector (see the assertions in setNextDocId). */ + private final BinaryDocValues input; + private final BinaryDocValues magnitudes; + private boolean decoded; + private final int dims; + private BytesRef value; + private BytesRef magnitudesValue; + private BFloat16VectorIterator vectorValues; + private int numVectors; + private float[] buffer; + /** Creates the field; {@code dims} is the element count of one vector and sizes the reusable decode buffer. */ + public BFloat16RankVectorsDocValuesField( + BinaryDocValues input, + BinaryDocValues magnitudes, + String name, + ElementType elementType, + int dims + ) { + super(name, elementType); + this.input = input; + this.magnitudes = magnitudes; + this.dims = dims; + this.buffer = new float[dims]; + } + /** Advances to {@code docId}; caches the raw vector and magnitude bytes when the document has a value, otherwise clears the cached state. */ + @Override + public void setNextDocId(int docId) throws IOException { + decoded = false; + if (input.advanceExact(docId)) { + boolean magnitudesFound = magnitudes.advanceExact(docId); + assert magnitudesFound; + + value = input.binaryValue(); + assert value.length % (BFloat16.BYTES * dims) == 0; + numVectors = value.length / (BFloat16.BYTES * dims); + magnitudesValue = magnitudes.binaryValue(); + assert magnitudesValue.length == (Float.BYTES * numVectors); + } else { + value = null; + magnitudesValue = null; + numVectors = 0; + } + } + /** Returns a new {@code RankVectorsScriptDocValues} backed by this field. */ + @Override + public RankVectorsScriptDocValues toScriptDocValues() { + return new RankVectorsScriptDocValues(this, dims); + } + /** Returns true when the current document has no vector value. */ + @Override + public boolean isEmpty() { + return value == null; + } + /** Returns the current document's vectors, or {@code RankVectors.EMPTY} when it has none. */ + @Override + public RankVectors get() { + if (isEmpty()) { +
return RankVectors.EMPTY; + } + decodeVectorIfNecessary(); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); + } + /** Same as {@link #get()}, except that {@code defaultValue} is returned when the document has no value. */ + @Override + public RankVectors get(RankVectors defaultValue) { + if (isEmpty()) { + return defaultValue; + } + decodeVectorIfNecessary(); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); + } + /** Returns the current document's vectors, or {@code null} when it has none. */ + @Override + public RankVectors getInternal() { + return get(null); + } + /** Returns the number of vectors in the current document, 0 when it has no value. */ + @Override + public int size() { + return value == null ? 0 : value.length / (BFloat16.BYTES * dims); + } + /** Creates the decoding iterator for the current value on first use after {@code setNextDocId}; later calls are no-ops. */ + private void decodeVectorIfNecessary() { + if (decoded == false && value != null) { + vectorValues = new BFloat16VectorIterator(value, buffer, numVectors); + decoded = true; + } + } + /** Iterates packed little-endian bfloat16 vectors, decoding each one into a single reused {@code float[]}. */ + public static class BFloat16VectorIterator implements VectorIterator { + private final float[] buffer; + private final ShortBuffer vectorValues; + private final BytesRef vectorValueBytesRef; + private final int size; + private int idx = 0; + /** Wraps the bytes of {@code vectorValues} without copying; asserts they hold {@code size} vectors of {@code buffer.length} elements. */ + public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) { + assert vectorValues.length == (buffer.length * BFloat16.BYTES * size); + this.vectorValueBytesRef = vectorValues; + this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + this.size = size; + this.buffer = buffer; + } + /** Returns true while fewer than {@code size} vectors have been returned. */ + @Override + public boolean hasNext() { + return idx < size; + } + /** Decodes the next vector into the shared buffer and returns that same array on every call; throws IllegalArgumentException when exhausted. NOTE(review): presumably bFloat16ToFloat advances the short buffer by one vector - confirm. */ + @Override + public float[] next() { + if (hasNext() == false) { + throw new IllegalArgumentException("No more elements in the iterator"); + } + BFloat16.bFloat16ToFloat(vectorValues, buffer); + idx++; + return buffer; + } + /** Returns a new iterator over the same bytes with its own decode buffer, starting at the first vector. */ + @Override + public Iterator copy() { + return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size); + } + /** Moves the iterator back to the first vector. */ + @Override + public void reset() { + idx = 0; + vectorValues.rewind(); + } + } +} diff --git
a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..38548deff5b45 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; +/** Disk BBQ format tests with bfloat16 raw vectors; byte-vector cases are skipped, float checks allow for the lost precision. */ +public class ES920DiskBBQBFloat16VectorsFormatTests extends ES920DiskBBQVectorsFormatTests { + @Override + boolean useBFloat16() { + return true; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testRandom() throws Exception { + AssertionError err =
expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + /** Captures the expected and actual numbers of a JUnit numeric mismatch message, including values printed in E-notation. */ + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.E-]+)> but was:<([0-9.E-]+)>"); + /** Rethrows {@code error} unless it is a numeric mismatch; a mismatch passes only if the actual value is within 1% of the expected one. */ + private static void assertFloatsWithinBounds(AssertionError error) { + Matcher m = FLOAT_ASSERTION_FAILURE.matcher(String.valueOf(error.getMessage())); // null message (bare assert) cannot match + if (m.matches() == false) { + throw error; // nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = Math.abs(expected) * 0.01; // within 1% of the magnitude, so the tolerance is never negative + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index f3cd4f92a6a87..1535f71e3ba54 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -60,7 +60,12 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized } - KnnVectorsFormat format; + + private KnnVectorsFormat format; + + boolean useBFloat16() { + return false; + } @Before @Override @@ -69,14 +74,16 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), - random().nextBoolean() + random().nextBoolean(), + useBFloat16() ); } else { // run with low numbers to force many clusters with parents format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), - random().nextBoolean() + random().nextBoolean(), + useBFloat16() ); } super.setUp(); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..9ae394733631e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.index.VectorEncoding; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES93BinaryQuantizedBFloat16VectorsFormatTests extends ES93BinaryQuantizedVectorsFormatTests { + @Override + boolean useBFloat16() { + return true; + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); + + private static void assertFloatsWithinBounds(AssertionError error) { + Matcher m = 
FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); + if (m.matches() == false) { + throw error; // nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = expected * 0.01; // within 1% + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 96538fd8dfb74..20689e773ee79 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -56,6 +56,7 @@ import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.IOException; import java.util.ArrayList; @@ -77,9 +78,13 @@ public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT private KnnVectorsFormat format; + boolean useBFloat16() { + return false; + } + @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(random().nextBoolean(), useBFloat16()); super.setUp(); } @@ -239,7 +244,8 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b assertEquals(expectVecOffHeap ? 2 : 1, offHeap.size()); assertTrue(offHeap.get("veb") > 0L); if (expectVecOffHeap) { - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + int bytes = useBFloat16() ? 
BFloat16.BYTES : Float.BYTES; + assertEquals(vector.length * bytes, (long) offHeap.get("vec")); } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..c6f3e555013b3 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.index.VectorEncoding; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; +/** HNSW binary-quantized format tests with bfloat16 raw vectors; byte-vector cases are skipped, float checks allow for the lost precision. */ +public class ES93HnswBinaryQuantizedBFloat16VectorsFormatTests extends ES93HnswBinaryQuantizedVectorsFormatTests { + + @Override + boolean useBFloat16() { + return true; + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testSingleVectorCase() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSingleVectorCase); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void
testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + /** Captures the expected and actual numbers of a JUnit numeric mismatch message, including values printed in E-notation. */ + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.E-]+)> but was:<([0-9.E-]+)>"); + /** Rethrows {@code error} unless it is a numeric mismatch; a mismatch passes only if the actual value is within 1% of the expected one. */ + private static void assertFloatsWithinBounds(AssertionError error) { + Matcher m = FLOAT_ASSERTION_FAILURE.matcher(String.valueOf(error.getMessage())); // null message (bare assert) cannot match + if (m.matches() == false) { + throw error; // nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = Math.abs(expected) * 0.01; // within 1% of the magnitude, so the tolerance is never negative + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 809436f139573..c5d85b5cc4681 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -46,6 +46,7 @@ import org.apache.lucene.util.SameThreadExecutorService; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.io.IOException; import java.util.Arrays; @@ -68,9 +69,13 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFor private KnnVectorsFormat format; + boolean useBFloat16() { + return false; + } + @Override public void setUp() throws Exception { - format = new
ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, random().nextBoolean(), useBFloat16()); super.setUp(); } @@ -83,7 +88,7 @@ public void testToString() { FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, 1, null); + return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, false, 1, null); } }; String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," @@ -137,15 +142,15 @@ public void testSingleVectorCase() throws Exception { } public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false, false)); + 
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false, false)); expectThrows( IllegalArgumentException.class, - () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, 1, new SameThreadExecutorService()) + () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) ); } @@ -189,7 +194,8 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b assertEquals(1L, (long) offHeap.get("vex")); assertTrue(offHeap.get("veb") > 0L); if (expectVecOffHeap) { - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + assertEquals(vector.length * bytes, (long) offHeap.get("vec")); } } } From b4c682acb6d869c20bbe43b7e03313939ba5aff7 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 11:07:32 +0100 Subject: [PATCH 02/46] Remove tripping assertion --- .../index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index fe26de0fe869c..4884572d99fc6 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -70,7 +70,6 @@ public RandomVectorScorer getRandomVectorScorer( assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer"; OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); float[] centroid = binarizedVectors.getCentroid(); - assert similarityFunction != COSINE || VectorUtil.isUnitVector(target); float[] scratch = new float[vectorValues.dimension()]; int[] initial = 
new int[target.length]; byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; From 803dfc15de49d45f340c2c7955abe9d3cb3e3558 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 11:12:48 +0100 Subject: [PATCH 03/46] Enable direct IO and bfloat16 --- .../search.vectors/41_knn_search_bbq_hnsw.yml | 62 ++ .../41_knn_search_bbq_hnsw_bfloat16.yml | 580 ++++++++++++++++++ .../42_knn_search_bbq_flat_bfloat16.yml | 512 ++++++++++++++++ .../elasticsearch/index/store/DirectIOIT.java | 2 +- .../index/mapper/BlockDocValuesReader.java | 28 +- .../index/mapper/MapperFeatures.java | 4 +- .../vectors/DenseVectorFieldMapper.java | 107 +++- .../mapper/vectors/VectorDVLeafFieldData.java | 6 +- .../script/VectorScoreScriptUtils.java | 8 +- .../BFloat16RankVectorsDocValuesField.java | 4 +- .../DenseVectorFieldMapperTestUtils.java | 6 +- .../vectors/DenseVectorFieldMapperTests.java | 16 +- .../vectors/DenseVectorFieldTypeTests.java | 8 +- ...AbstractKnnVectorQueryBuilderTestCase.java | 4 +- .../TestDenseInferenceServiceExtension.java | 2 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../elastic/ElasticTextEmbeddingPayload.java | 2 +- ...cInferenceMetadataFieldsRecoveryTests.java | 2 +- .../mapper/SemanticTextFieldMapperTests.java | 2 +- .../mapper/SemanticTextFieldTests.java | 8 +- .../queries/SemanticQueryBuilderTests.java | 2 +- .../mapper/RankVectorsDVLeafFieldData.java | 45 +- .../script/RankVectorsScoreScriptUtils.java | 2 +- .../mapper/RankVectorsFieldMapperTests.java | 4 +- .../RankVectorsScriptDocValuesTests.java | 122 +++- 25 files changed, 1469 insertions(+), 71 deletions(-) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml diff --git 
a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index e3c1155ed2000..b58ae2a29839a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -338,6 +338,68 @@ setup: - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } --- +"Test index configured rescore vector with on-disk rescoring": + - requires: + cluster_features: ["mapper.vectors.bbq_hnsw_on_disk_rescoring"] + reason: Needs on_disk_rescoring feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_on_disk_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + on_disk_rescore: true + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_on_disk_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 
0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_on_disk_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } +--- "Test index configured rescore vector updateable and settable to 0": - requires: cluster_features: ["mapper.dense_vector.rescore_zero_vector"] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml new file mode 100644 index 0000000000000..980d9ba924fea --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -0,0 +1,580 @@ +setup: + - requires: + cluster_features: "mapper.vectors.bbq_hnsw_on_disk_rescoring" + reason: 'bfloat16 needs to be supported' + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: 
bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 + + - do: + indices.refresh: { } +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: 
[knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad quantization parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: false + index_options: + type: bbq_hnsw +--- +"Test few dimensions fail indexing": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 42 + index: true + index_options: + type: bbq_hnsw + + - do: + indices.create: + index: dynamic_dim_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index: true + similarity: l2_norm + index_options: + type: bbq_hnsw + + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + - do: + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 
2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, 
-0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector updateable and settable to 0": + - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + + - do: + indices.create: + index: bbq_rescore_0_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + indices.create: + index: bbq_rescore_update_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1 + + - do: + indices.put_mapping: + index: bbq_rescore_update_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + indices.get_mapping: + index: bbq_rescore_update_hnsw + + - match: { .bbq_rescore_update_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 0 } +--- +"Test index configured rescore vector score consistency": 
+ - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_zero_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + bulk: + index: bbq_rescore_zero_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, 
-0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: raw_score0 } + - set: { hits.hits.1._score: raw_score1 } + - set: { hits.hits.2._score: raw_score2 } + + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 2 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: override_score0 } + - set: { hits.hits.1._score: override_score1 } + - set: { hits.hits.2._score: override_score2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + 
element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 2 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: default_rescore0 } + - set: { hits.hits.1._score: default_rescore1 } + - set: { hits.hits.2._score: default_rescore2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $override_score0 } + - match: { hits.hits.0._score: $default_rescore0 } + - match: { hits.hits.1._score: $override_score1 } + - match: { hits.hits.1._score: $default_rescore1 } + - match: { hits.hits.2._score: $override_score2 } + - match: { hits.hits.2._score: $default_rescore2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_hnsw + + - match: { bbq_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml new file mode 100644 index 0000000000000..2b801e92d7b7c --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -0,0 +1,512 @@ +setup: + - requires: + cluster_features: "mapper.vectors.bbq_hnsw_on_disk_rescoring" + reason: 'bfloat16 needs to be supported' + - do: + indices.create: + index: bbq_flat + body: + settings: + index: + number_of_shards: 1 + 
mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + + - do: + index: + index: bbq_flat + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "3" + body: + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, 
-0.068, -0.058] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + indices.forcemerge: + index: bbq_flat + max_num_segments: 1 +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 
0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + query: + script_score: + query: { match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + index_options: + type: bbq_flat + m: 42 +--- +"Test bad raw vector size": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + 
properties: + vector: + type: dense_vector + dims: 64 + index: true + index_options: + type: bbq_flat + raw_vector_size: 25 +--- +"Test few dimensions fail indexing": + # verify index creation fails + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 42 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + + # create an index without explicit dims (dynamic dimensions) + - do: + indices.create: + index: dynamic_dim_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + + # verify indexing fails for a vector with too few dimensions + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + # verify that we can index a vector with enough dimensions after the too-few-dimensions failure + - do: + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 
0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 
, -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_flat + + - match: { bbq_flat.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } +--- +"Test nested queries": + - do: + indices.create: + index: bbq_flat_nested + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + name: + type: keyword + nested: + type: nested + properties: + 
paragraph_id: + type: keyword + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + + - do: + index: + index: bbq_flat_nested + id: "1" + body: + nested: + - paragraph_id: "1" + vector: [ 0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45 ] + - paragraph_id: "2" + vector: [ 0.7, 0.2 , 0.205, 0.63 , 0.032, 0.201, 0.167, 0.313, + 0.176, 0.1, 0.375, 0.334, 0.046, 0.078, 0.349, 0.272, + 0.307, 0.083, 0.504, 0.255, 0.404, 0.289, 0.226, 0.132, + 0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , 0.265, + 0.285, 0.336, 0.272, 0.369, -0.282, 0.086, 0.132, 0.475, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.028, 0.321, 0.022, 0.009, 0.001 , 0.031, -0.533, 0.45] + - do: + index: + index: bbq_flat_nested + id: "2" + body: + nested: + - paragraph_id: 0 + vector: [ 0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, -0.013 ] + - paragraph_id: 2 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, 
-0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, 0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, 0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, 0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, 0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, 0.013 ] + - paragraph_id: 3 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + 0.08 , 0.442, -0.067, -0.05 , 0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, 0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, 0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, 0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, 0.209, -0.153, -0.27, -0.013 ] + + - do: + index: + index: bbq_flat_nested + id: "3" + body: + nested: + - paragraph_id: 0 + vector: [ 0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058 ] + + - do: + indices.flush: + index: bbq_flat_nested + + - do: + indices.forcemerge: + index: bbq_flat_nested + max_num_segments: 1 + + - do: + search: + index: bbq_flat_nested + body: + query: + nested: + path: nested + query: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, 
-0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} + + - do: + search: + index: bbq_flat_nested + body: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index efbc19b30079c..ae888f50155cd 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -73,7 +73,7 @@ protected boolean useDirectIO(String name, IOContext context, OptionalLong fileL @ParametersFactory public static Iterable parameters() { - return List.of(new Object[] { "bbq_disk" }); + return List.of(new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" }); } public DirectIOIT(String type) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 457c90383b5d2..70deb110b3a87 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.ByteArrayStreamInput; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.BlockLoader.BlockFactory; import org.elasticsearch.index.mapper.BlockLoader.BooleanBuilder; import org.elasticsearch.index.mapper.BlockLoader.Builder; @@ -36,6 +37,9 @@ import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; @@ -536,7 +540,8 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { switch (fieldType.getElementType()) { - case FLOAT -> { + case FLOAT, BFLOAT16 -> { + // BFloat16 is handled by the implementation of FloatVectorValues FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { if (fieldType.isNormalized()) { @@ -1052,6 +1057,7 @@ public AllReader reader(LeafReaderContext context) throws IOException { } return switch (elementType) { case FLOAT -> new FloatDenseVectorFromBinary(docValues, dims, indexVersion); + case BFLOAT16 -> new BFloat16DenseVectorFromBinary(docValues, dims, indexVersion); case BYTE -> new ByteDenseVectorFromBinary(docValues, dims, indexVersion); case BIT -> new BitDenseVectorFromBinary(docValues, dims, indexVersion); }; @@ -1135,6 +1141,26 @@ public String toString() { } } + private static class BFloat16DenseVectorFromBinary extends FloatDenseVectorFromBinary { + BFloat16DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion); + } + + @Override 
+ protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); + ShortBuffer sb = ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + BFloat16.bFloat16ToFloat(sb, scratch); + } + + @Override + public String toString() { + return "BFloat16DenseVectorFromBinary.Bytes"; + } + } + private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary { ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { this(docValues, dims, indexVersion, dims); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index e256453ac6d31..9b0300d21e148 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -54,6 +54,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature PATTERN_TEXT_RENAME = new NodeFeature("mapper.pattern_text_rename"); static final NodeFeature DISKBBQ_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.diskbbq_on_disk_rescoring"); static final NodeFeature PROVIDE_INDEX_SORT_SETTING_DEFAULTS = new NodeFeature("mapper.provide_index_sort_setting_defaults"); + static final NodeFeature BBQ_HNSW_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.bbq_hnsw_on_disk_rescoring"); @Override public Set getTestFeatures() { @@ -93,7 +94,8 @@ public Set getTestFeatures() { MATCH_ONLY_TEXT_BLOCK_LOADER_FIX, PATTERN_TEXT_RENAME, DISKBBQ_ON_DISK_RESCORING, - PROVIDE_INDEX_SORT_SETTING_DEFAULTS + PROVIDE_INDEX_SORT_SETTING_DEFAULTS, + BBQ_HNSW_ON_DISK_RESCORING ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c40986be0ffbf..1a506a5e94cdf 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -47,14 +47,15 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -391,6 +392,7 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo return new BBQHnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + false, new RescoreVector(DEFAULT_OVERSAMPLE) ); } else if (defaultInt8Hnsw) { @@ -460,6 +462,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) { public enum ElementType { BYTE, FLOAT, + BFLOAT16, BIT; public static ElementType 
fromString(String name) { @@ -474,6 +477,7 @@ public String toString() { public static final Element BYTE_ELEMENT = new ByteElement(); public static final Element FLOAT_ELEMENT = new FloatElement(); + public static final Element BFLOAT16_ELEMENT = new BFloat16Element(); public static final Element BIT_ELEMENT = new BitElement(); public static final Map namesToElementType = Map.of( @@ -481,6 +485,8 @@ public String toString() { ElementType.BYTE, ElementType.FLOAT.toString(), ElementType.FLOAT, + ElementType.BFLOAT16.toString(), + ElementType.BFLOAT16, ElementType.BIT.toString(), ElementType.BIT ); @@ -490,6 +496,7 @@ public abstract static class Element { public static Element getElement(ElementType elementType) { return switch (elementType) { case FLOAT -> FLOAT_ELEMENT; + case BFLOAT16 -> BFLOAT16_ELEMENT; case BYTE -> BYTE_ELEMENT; case BIT -> BIT_ELEMENT; }; @@ -1055,6 +1062,29 @@ static UnaryOperator errorElementsAppender(float[] vector) { } } + private static class BFloat16Element extends FloatElement { + + @Override + public ElementType elementType() { + return ElementType.BFLOAT16; + } + + @Override + public void writeValue(ByteBuffer byteBuffer, float value) { + byteBuffer.putShort(BFloat16.floatToBFloat16(value)); + } + + @Override + public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { + b.value(BFloat16.bFloat16ToFloat(byteBuffer.getShort())); + } + + @Override + public int getNumBytes(int dimensions) { + return dimensions * BFloat16.BYTES; + } + } + private static class BitElement extends ByteElement { @Override @@ -1122,7 +1152,7 @@ public enum VectorSimilarity { @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> 1f / (1f + similarity * similarity); + case BYTE, FLOAT, BFLOAT16 -> 1f / (1f + similarity * similarity); case BIT -> (dim - similarity) / dim; }; } @@ -1137,14 +1167,14 @@ public VectorSimilarityFunction 
vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { assert elementType != ElementType.BIT; return switch (elementType) { - case BYTE, FLOAT -> (1 + similarity) / 2f; + case BYTE, FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @Override public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersion, ElementType elementType) { - return indexVersion.onOrAfter(NORMALIZE_COSINE) && ElementType.FLOAT.equals(elementType) + return indexVersion.onOrAfter(NORMALIZE_COSINE) && (elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16) ? VectorSimilarityFunction.DOT_PRODUCT : VectorSimilarityFunction.COSINE; } @@ -1154,7 +1184,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15)); - case FLOAT -> (1 + similarity) / 2f; + case FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1168,7 +1198,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; + case BYTE, FLOAT, BFLOAT16 -> similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1457,8 +1487,11 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, onDiskRescore, false); + case BFLOAT16 -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, onDiskRescore, true); + default -> throw new AssertionError(); + }; } @Override @@ -2051,12 +2089,15 @@ public boolean updatableTo(DenseVectorIndexOptions update) { @Override boolean doEquals(DenseVectorIndexOptions other) { BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); + return m == that.m + && efConstruction == that.efConstruction + && onDiskRescore == that.onDiskRescore + && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(m, efConstruction, rescoreVector); + return Objects.hash(m, efConstruction, onDiskRescore, rescoreVector); } @Override @@ -2070,6 +2111,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); + if (onDiskRescore) { + builder.field("on_disk_rescore", true); + } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } @@ -2098,8 +2142,11 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - assert elementType == ElementType.FLOAT; - return new ES818BinaryQuantizedVectorsFormat(); + return switch (elementType) { + case FLOAT -> new ES93BinaryQuantizedVectorsFormat(false, false); + case BFLOAT16 -> new ES93BinaryQuantizedVectorsFormat(false, true); + default -> throw new AssertionError(); + }; } @Override @@ -2162,13 +2209,21 @@ static class 
BBQIVFIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - assert elementType == ElementType.FLOAT; - return new ES920DiskBBQVectorsFormat( - clusterSize, - ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - onDiskRescore, - false - ); + return switch (elementType) { + case FLOAT -> new ES920DiskBBQVectorsFormat( + clusterSize, + ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, + onDiskRescore, + false + ); + case BFLOAT16 -> new ES920DiskBBQVectorsFormat( + clusterSize, + ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, + onDiskRescore, + true + ); + default -> throw new AssertionError(); + }; } @Override @@ -2348,7 +2403,7 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) } Query knnQuery = switch (element.elementType()) { case BYTE -> createExactKnnByteQuery(queryVector.asByteVector()); - case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector()); + case FLOAT, BFLOAT16 -> createExactKnnFloatQuery(queryVector.asFloatVector()); case BIT -> createExactKnnBitQuery(queryVector.asByteVector()); }; if (vectorSimilarity != null) { @@ -2429,7 +2484,7 @@ public Query createKnnQuery( knnSearchStrategy, hnswEarlyTermination ); - case FLOAT -> createKnnFloatQuery( + case FLOAT, BFLOAT16 -> createKnnFloatQuery( queryVector.asFloatVector(), k, numCands, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index 120682d185535..f0e61cb38b4bc 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -69,14 +69,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { if (indexed) { return switch (elementType) { case BYTE -> new 
ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); - case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); + case FLOAT, BFLOAT16 -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); }; } else { BinaryDocValues values = DocValues.getBinary(reader, field); return switch (elementType) { case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims); - case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case FLOAT, BFLOAT16 -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims); }; } @@ -138,7 +138,7 @@ public Object nextValue() { return vectorValue; } }; - case FLOAT -> new FormattedDocValues() { + case FLOAT, BFLOAT16 -> new FormattedDocValues() { float[] vector = new float[dims]; private FloatVectorValues floatVectorValues; // use when indexed private KnnVectorValues.DocIndexIterator iterator; // use when indexed diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index 13be089753fb7..b29a0266c220c 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -210,7 +210,7 @@ public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL1Norm(scoreScript, field, (List) queryVector); } @@ -320,7 +320,7 @@ public L2Norm(ScoreScript 
scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL2Norm(scoreScript, field, (List) queryVector); } @@ -478,7 +478,7 @@ public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatDotProduct(scoreScript, field, (List) queryVector); } @@ -547,7 +547,7 @@ public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fiel } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatCosineSimilarity(scoreScript, field, (List) queryVector); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java index 48b44df9732fd..bdb82d518d02f 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java @@ -138,7 +138,9 @@ public float[] next() { if (hasNext() == false) { throw new IllegalArgumentException("No more elements in the iterator"); } - BFloat16.bFloat16ToFloat(vectorValues, buffer); + for (int i = 0; i < buffer.length; i++) { + buffer[i] = BFloat16.bFloat16ToFloat(vectorValues.get()); + } idx++; return buffer; } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java 
b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java index 9478508da88d0..c3e3ed0c93c98 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java @@ -22,14 +22,14 @@ private DenseVectorFieldMapperTestUtils() {} public static List getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) { return switch (elementType) { - case FLOAT, BYTE -> List.of(SimilarityMeasure.values()); + case FLOAT, BFLOAT16, BYTE -> List.of(SimilarityMeasure.values()); case BIT -> List.of(SimilarityMeasure.L2_NORM); }; } public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BFLOAT16, BYTE -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; @@ -43,7 +43,7 @@ public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType } return switch (elementType) { - case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); + case FLOAT, BFLOAT16, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); case BIT -> { if (max < 8) { throw new IllegalArgumentException("max must be at least 8 for bit vectors"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 5932180ac3c03..436136fe526da 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -69,7 +69,9 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE; import static 
org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -2447,7 +2449,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) ft; return switch (vectorFieldType.getElementType()) { case BYTE -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions()); - case FLOAT -> randomNormalizedVector(vectorFieldType.getVectorDimensions()); + case FLOAT, BFLOAT16 -> randomNormalizedVector(vectorFieldType.getVectorDimensions()); case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8); }; } @@ -2896,14 +2898,14 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + String expectedString = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=" + m + ", beamWidth=" + efConstruction - + ", flatVectorFormat=ES818BinaryQuantizedVectorsFormat(" - + "name=ES818BinaryQuantizedVectorsFormat, " - + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; - assertEquals(expectedString, knnVectorsFormat.toString()); + + ", flatVectorFormat=ES93BinaryQuantizedVectorsFormat(" + + "name=ES93BinaryQuantizedVectorsFormat, " + + "writeFlatVectorFormat=Lucene99FlatVectorsFormat"; + assertThat(knnVectorsFormat, hasToString(startsWith(expectedString))); } public void testKnnBBQIVFVectorsFormat() throws IOException { @@ -3042,7 +3044,7 @@ public SyntheticSourceExample example(int 
maxValues) throws IOException { Object value = switch (elementType) { case BYTE, BIT: yield randomList(dims, dims, ESTestCase::randomByte); - case FLOAT: + case FLOAT, BFLOAT16: yield randomList(dims, dims, ESTestCase::randomFloat); }; return new SyntheticSourceExample(value, value, this::mapping); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 333722a4b438f..efc71f029e781 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -124,6 +124,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsA new DenseVectorFieldMapper.BBQHnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomBoolean(), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQFlatIndexOptions( @@ -164,7 +165,12 @@ private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsHnswQua randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), rescoreVector ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), rescoreVector) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomBoolean(), + rescoreVector + ) ); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 945f831e26231..d8ce4f1333055 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ 
b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -245,7 +245,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que approxFilterQuery, expectedStrategy ); - case FLOAT -> new ESKnnFloatVectorQuery( + case FLOAT, BFLOAT16 -> new ESKnnFloatVectorQuery( VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, @@ -266,7 +266,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (filterQuery != null) { yield new BooleanQuery.Builder().add( new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD), diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 051b6dbf3e8fa..cb8f5a5ebf6ad 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -255,7 +255,7 @@ private static List generateEmbedding(String input, int dimensions, Dense // Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BYTE, BFLOAT16 -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 1a8b162eb1b46..7f486f9122008 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1390,7 +1390,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDense int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector); } static SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index a5fd194f12109..894f2a820461c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -97,7 +97,7 @@ public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoint return switch (model.apiServiceSettings().elementType()) { case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); case BYTE -> 
TextEmbeddingBytes.PARSER.apply(p, null); - case FLOAT -> TextEmbeddingFloat.PARSER.apply(p, null); + case FLOAT, BFLOAT16 -> TextEmbeddingFloat.PARSER.apply(p, null); }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 175c3e90f798d..9bc1736a85c7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -269,7 +269,7 @@ private static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index f47d7c4c37261..b0c3830955b93 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -1410,7 +1410,7 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDens int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = 
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector); } private static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index d1499f4009d0a..fb6b5b85c9414 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -190,7 +190,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model mo return switch (model.getTaskType()) { case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); @@ -222,7 +222,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); - assert elementType == 
DenseVectorFieldMapper.ElementType.FLOAT; + assert elementType == DenseVectorFieldMapper.ElementType.FLOAT || elementType == DenseVectorFieldMapper.ElementType.BFLOAT16; List chunks = new ArrayList<>(); for (String input : inputs) { @@ -272,7 +272,7 @@ public static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); @@ -415,7 +415,7 @@ public static ChunkedInference toChunkedResult( ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText); double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType()); EmbeddingResults.Embedding embedding = switch (elementType) { - case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case FLOAT, BFLOAT16 -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values)); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b2d7218720a57..ab4116a77c4ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -278,7 +278,7 @@ private void 
assertTextEmbeddingLuceneQuery(Query query) { Query innerQuery = assertOuterBooleanQuery(query); Class expectedKnnQueryClass = switch (denseVectorElementType) { - case FLOAT -> KnnFloatVectorQuery.class; + case FLOAT, BFLOAT16 -> KnnFloatVectorQuery.class; case BYTE, BIT -> KnnByteVectorQuery.class; }; assertThat(innerQuery, instanceOf(expectedKnnQueryClass)); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java index b858b935c1483..60d6bfc5586ee 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; @@ -24,7 +25,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; final class RankVectorsDVLeafFieldData implements LeafFieldData { @@ -123,7 +123,47 @@ public Object nextValue() { VectorIterator iterator = new FloatRankVectorsDocValuesField.FloatVectorIterator(ref, vector, numVecs); while (iterator.hasNext()) { float[] v = iterator.next(); - vectors.add(Arrays.copyOf(v, v.length)); + vectors.add(v.clone()); + } + return vectors; + } + }; + case BFLOAT16 -> new FormattedDocValues() { + private final 
float[] vector = new float[dims]; + private BytesRef ref = null; + private int numVecs = -1; + private final BinaryDocValues binary; + { + try { + binary = DocValues.getBinary(reader, field); + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } + } + + @Override + public boolean advanceExact(int docId) throws IOException { + if (binary == null || binary.advanceExact(docId) == false) { + return false; + } + ref = binary.binaryValue(); + assert ref.length % (Short.BYTES * dims) == 0; + numVecs = ref.length / (Short.BYTES * dims); + return true; + } + + @Override + public int docValueCount() { + return 1; + } + + @Override + public Object nextValue() { + List vectors = new ArrayList<>(numVecs); + VectorIterator iterator = new BFloat16RankVectorsDocValuesField.BFloat16VectorIterator(ref, vector, numVecs); + while (iterator.hasNext()) { + float[] v = iterator.next(); + vectors.add(v.clone()); } return vectors; } @@ -140,6 +180,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case BFLOAT16 -> new BFloat16RankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); }; } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java index 1c533e9ec88cd..d846a54d0bc83 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java +++ 
b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java @@ -351,7 +351,7 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new MaxSimFloatDotProduct(scoreScript, field, (List>) queryVector); } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index ad29a191aace3..589cb9c5f3ed9 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -417,7 +417,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield vectors; } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { float[][] vectors = new float[numVectors][vectorFieldType.getVectorDimensions()]; for (int i = 0; i < numVectors; i++) { for (int j = 0; j < vectorFieldType.getVectorDimensions(); j++) { @@ -473,7 +473,7 @@ public SyntheticSourceExample example(int maxValues) { Object value = switch (elementType) { case BYTE, BIT: yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); - case FLOAT: + case FLOAT, BFLOAT16: yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); }; return new SyntheticSourceExample(value, value, this::mapping); diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java 
b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java index 127ad6c7dbe43..d8094195a4508 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; +import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.RankVectors; @@ -52,6 +53,36 @@ public void testFloatGetVectorValueAndGetMagnitude() throws IOException { } } + public void testBFloat16GetVectorValueAndGetMagnitude() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + assertEquals(vectors[i].length, field.size()); + assertEquals(dims, scriptDocValues.dims()); + Iterator iterator = scriptDocValues.getVectorValues(); + float[] magnitudes = scriptDocValues.getMagnitudes(); + 
assertEquals(expectedMagnitudes[i].length, magnitudes.length); + for (int j = 0; j < vectors[i].length; j++) { + assertTrue(iterator.hasNext()); + assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f); + assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f); + } + } + } + public void testByteGetVectorValueAndGetMagnitude() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -98,6 +129,36 @@ public void testFloatMetadataAndIterator() throws IOException { assertEquals(dv, RankVectors.EMPTY); } + public void testBFloat16MetadataAndIterator() throws IOException { + int dims = 3; + float[][][] vectors = new float[][][] { + fill(new float[3][dims], ElementType.BFLOAT16), + fill(new float[2][dims], ElementType.BFLOAT16) }; + float[][] magnitudes = new float[][] { new float[3], new float[2] }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + RankVectors dv = field.get(); + assertEquals(vectors[i].length, dv.size()); + assertFalse(dv.isEmpty()); + assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator); + assertEquals("Cannot iterate over single valued rank_vectors field, use get() instead", e.getMessage()); + } + field.setNextDocId(vectors.length); + RankVectors dv = field.get(); + assertEquals(dv, RankVectors.EMPTY); + } + public void testByteMetadataAndIterator() throws IOException { int dims = 3; float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BYTE), fill(new float[2][dims], ElementType.BYTE) }; @@ -146,6 +207,30 @@ public void testFloatMissingValues() throws IOException { 
assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); } + public void testBFloat16MissingValues() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(3); + assertEquals(0, field.size()); + Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); + } + public void testByteMissingValues() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -184,6 +269,32 @@ public void testFloatGetFunctionIsNotAccessible() throws IOException { ); } + public void testBFloat16GetFunctionIsNotAccessible() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(0); +
Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat( + e.getMessage(), + containsString( + "accessing a rank-vectors field's value through 'get' or 'value' is not supported," + + " use 'vectorValues' or 'magnitudes' instead." + ) + ); + } + public void testByteGetFunctionIsNotAccessible() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -306,12 +417,11 @@ public static BytesRef mockEncodeDenseVector(float[][] values, ElementType eleme ByteBuffer byteBuffer = element.createByteBuffer(indexVersion, numBytes * values.length); for (float[] vector : values) { for (float value : vector) { - if (elementType == ElementType.FLOAT) { - byteBuffer.putFloat(value); - } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) { - byteBuffer.put((byte) value); - } else { - throw new IllegalStateException("unknown element_type [" + elementType + "]"); + switch (elementType) { + case FLOAT -> byteBuffer.putFloat(value); + case BFLOAT16 -> byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); + case BYTE, BIT -> byteBuffer.put((byte) value); + default -> throw new IllegalStateException("unknown element_type [" + elementType + "]"); } } } From c2c52e7f1b7511cd1c45caac82952a7b59b96cfe Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 10 Oct 2025 10:15:43 +0000 Subject: [PATCH 04/46] [CI] Auto commit changes from spotless --- .../index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index 4884572d99fc6..a4fcf71cf94be 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -32,7 +32,6 @@ import java.io.IOException; -import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; From 2f8e975152a56d93486545410b71d10feeabe4de Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 11:48:47 +0100 Subject: [PATCH 05/46] Remove tripping assertion --- .../index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index fe26de0fe869c..a4fcf71cf94be 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -32,7 +32,6 @@ import java.io.IOException; -import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; @@ -70,7 +69,6 @@ public RandomVectorScorer getRandomVectorScorer( assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer"; OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); float[] centroid = binarizedVectors.getCentroid(); - assert similarityFunction != COSINE || VectorUtil.isUnitVector(target); float[] scratch = new float[vectorValues.dimension()]; int[] initial = new int[target.length]; byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; From 
4030f47742603b6d9c0f26a0b7617f3b08fb3837 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 16:20:21 +0100 Subject: [PATCH 06/46] PR comments --- .../elasticsearch/test/knn/CmdLineArgs.java | 404 ------------------ .../index/codec/vectors/BFloat16.java | 2 +- .../diskbbq/ES920DiskBBQVectorsFormat.java | 6 +- .../ES93BinaryQuantizedVectorsFormat.java | 40 +- .../es93/ES93GenericFlatVectorsFormat.java | 42 +- .../ES93HnswBinaryQuantizedVectorsFormat.java | 4 +- .../es93/OffHeapBFloat16VectorValues.java | 10 +- .../vectors/DenseVectorFieldMapper.java | 3 +- .../BFloat16RankVectorsDocValuesField.java | 157 ------- ...S920DiskBBQBFloat16VectorsFormatTests.java | 96 ----- .../ES920DiskBBQVectorsFormatTests.java | 10 +- ...ES93BinaryQuantizedVectorsFormatTests.java | 2 +- 12 files changed, 37 insertions(+), 739 deletions(-) delete mode 100644 qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java delete mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java delete mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java deleted file mode 100644 index 27272418b29f0..0000000000000 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.test.knn; - -import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.PathUtils; -import org.elasticsearch.monitor.jvm.JvmInfo; -import org.elasticsearch.xcontent.ObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; -import java.nio.file.Path; -import java.util.List; -import java.util.Locale; - -/** - * Command line arguments for the KNN index tester. - * This class encapsulates all the parameters required to run the KNN index tests. - */ -record CmdLineArgs( - List docVectors, - Path queryVectors, - int numDocs, - int numQueries, - KnnIndexTester.IndexType indexType, - int numCandidates, - int k, - double[] visitPercentages, - int ivfClusterSize, - int overSamplingFactor, - int hnswM, - int hnswEfConstruction, - int searchThreads, - int numSearchers, - int indexThreads, - boolean reindex, - boolean forceMerge, - float filterSelectivity, - long seed, - VectorSimilarityFunction vectorSpace, - int rawVectorSize, - int quantizeBits, - VectorEncoding vectorEncoding, - int dimensions, - boolean earlyTermination, - KnnIndexTester.MergePolicyType mergePolicy, - double writerBufferSizeInMb, - int writerMaxBufferedDocs -) implements ToXContentObject { - - static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); - static final ParseField QUERY_VECTORS_FIELD = new ParseField("query_vectors"); - static final ParseField NUM_DOCS_FIELD = new ParseField("num_docs"); - static final ParseField NUM_QUERIES_FIELD = new ParseField("num_queries"); - static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type"); - static final ParseField NUM_CANDIDATES_FIELD = new 
ParseField("num_candidates"); - static final ParseField K_FIELD = new ParseField("k"); - // static final ParseField N_PROBE_FIELD = new ParseField("n_probe"); - static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); - static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size"); - static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor"); - static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m"); - static final ParseField HNSW_EF_CONSTRUCTION_FIELD = new ParseField("hnsw_ef_construction"); - static final ParseField NUM_SEARCHERS_FIELD = new ParseField("num_searchers"); - static final ParseField SEARCH_THREADS_FIELD = new ParseField("search_threads"); - static final ParseField INDEX_THREADS_FIELD = new ParseField("index_threads"); - static final ParseField REINDEX_FIELD = new ParseField("reindex"); - static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); - static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); - static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); - static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size"); - static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); - static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); - static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity"); - static final ParseField SEED_FIELD = new ParseField("seed"); - static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy"); - static final ParseField WRITER_BUFFER_MB_FIELD = new ParseField("writer_buffer_mb"); - static final ParseField WRITER_BUFFER_DOCS_FIELD = new ParseField("writer_buffer_docs"); - - /** By default, in ES the default writer buffer size is 10% of the heap space - * (see {@code 
IndexingMemoryController.INDEX_BUFFER_SIZE_SETTING}). - * We configure the Java heap size for this tool in {@code build.gradle}; currently we default to 16GB, so in that case - * the buffer size would be 1.6GB. - */ - static final double DEFAULT_WRITER_BUFFER_MB = (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() / (1024.0 * 1024.0)) * 0.1; - - static CmdLineArgs fromXContent(XContentParser parser) throws IOException { - Builder builder = PARSER.apply(parser, null); - return builder.build(); - } - - static final ObjectParser PARSER = new ObjectParser<>("cmd_line_args", false, Builder::new); - - static { - PARSER.declareStringArray(Builder::setDocVectors, DOC_VECTORS_FIELD); - PARSER.declareString(Builder::setQueryVectors, QUERY_VECTORS_FIELD); - PARSER.declareInt(Builder::setNumDocs, NUM_DOCS_FIELD); - PARSER.declareInt(Builder::setNumQueries, NUM_QUERIES_FIELD); - PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD); - PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD); - PARSER.declareInt(Builder::setK, K_FIELD); - // PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD); - PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD); - PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD); - PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD); - PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD); - PARSER.declareInt(Builder::setHnswEfConstruction, HNSW_EF_CONSTRUCTION_FIELD); - PARSER.declareInt(Builder::setSearchThreads, SEARCH_THREADS_FIELD); - PARSER.declareInt(Builder::setNumSearchers, NUM_SEARCHERS_FIELD); - PARSER.declareInt(Builder::setIndexThreads, INDEX_THREADS_FIELD); - PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); - PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); - PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); - PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD); - 
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); - PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); - PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); - PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD); - PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD); - PARSER.declareLong(Builder::setSeed, SEED_FIELD); - PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD); - PARSER.declareDouble(Builder::setWriterBufferMb, WRITER_BUFFER_MB_FIELD); - PARSER.declareInt(Builder::setWriterMaxBufferedDocs, WRITER_BUFFER_DOCS_FIELD); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (docVectors != null) { - List docVectorsStrings = docVectors.stream().map(Path::toString).toList(); - builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectorsStrings); - } - if (queryVectors != null) { - builder.field(QUERY_VECTORS_FIELD.getPreferredName(), queryVectors.toString()); - } - builder.field(NUM_DOCS_FIELD.getPreferredName(), numDocs); - builder.field(NUM_QUERIES_FIELD.getPreferredName(), numQueries); - builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT)); - builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates); - builder.field(K_FIELD.getPreferredName(), k); - // builder.field(N_PROBE_FIELD.getPreferredName(), nProbes); - builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages); - builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize); - builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor); - builder.field(HNSW_M_FIELD.getPreferredName(), hnswM); - builder.field(HNSW_EF_CONSTRUCTION_FIELD.getPreferredName(), hnswEfConstruction); - builder.field(SEARCH_THREADS_FIELD.getPreferredName(), searchThreads); - builder.field(NUM_SEARCHERS_FIELD.getPreferredName(), 
numSearchers); - builder.field(INDEX_THREADS_FIELD.getPreferredName(), indexThreads); - builder.field(REINDEX_FIELD.getPreferredName(), reindex); - builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); - builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); - builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize); - builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); - builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); - builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); - builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination); - builder.field(FILTER_SELECTIVITY_FIELD.getPreferredName(), filterSelectivity); - builder.field(SEED_FIELD.getPreferredName(), seed); - builder.field(WRITER_BUFFER_MB_FIELD.getPreferredName(), writerBufferSizeInMb); - builder.field(WRITER_BUFFER_DOCS_FIELD.getPreferredName(), writerMaxBufferedDocs); - return builder.endObject(); - } - - @Override - public String toString() { - return Strings.toString(this, false, false); - } - - static class Builder { - private List docVectors; - private Path queryVectors; - private int numDocs = 1000; - private int numQueries = 100; - private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW; - private int numCandidates = 1000; - private int k = 10; - private double[] visitPercentages = new double[] { 1.0 }; - private int ivfClusterSize = 1000; - private int overSamplingFactor = 1; - private int hnswM = 16; - private int hnswEfConstruction = 200; - private int searchThreads = 1; - private int numSearchers = 1; - private int indexThreads = 1; - private boolean reindex = false; - private boolean forceMerge = false; - private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; - private int rawVectorSize = 32; - private int quantizeBits = 8; - private VectorEncoding vectorEncoding = 
VectorEncoding.FLOAT32; - private int dimensions; - private boolean earlyTermination; - private float filterSelectivity = 1f; - private long seed = 1751900822751L; - private KnnIndexTester.MergePolicyType mergePolicy = null; - private double writerBufferSizeInMb = DEFAULT_WRITER_BUFFER_MB; - - /** - * Elasticsearch does not set this explicitly, and in Lucene this setting is - * disabled by default (writer flushes by RAM usage). - */ - private int writerMaxBufferedDocs = IndexWriterConfig.DISABLE_AUTO_FLUSH; - - public Builder setDocVectors(List docVectors) { - if (docVectors == null || docVectors.isEmpty()) { - throw new IllegalArgumentException("Document vectors path must be provided"); - } - // Convert list of strings to list of Paths - this.docVectors = docVectors.stream().map(PathUtils::get).toList(); - return this; - } - - public Builder setQueryVectors(String queryVectors) { - this.queryVectors = PathUtils.get(queryVectors); - return this; - } - - public Builder setNumDocs(int numDocs) { - this.numDocs = numDocs; - return this; - } - - public Builder setNumQueries(int numQueries) { - this.numQueries = numQueries; - return this; - } - - public Builder setIndexType(String indexType) { - this.indexType = KnnIndexTester.IndexType.valueOf(indexType.toUpperCase(Locale.ROOT)); - return this; - } - - public Builder setNumCandidates(int numCandidates) { - this.numCandidates = numCandidates; - return this; - } - - public Builder setK(int k) { - this.k = k; - return this; - } - - public Builder setVisitPercentages(List visitPercentages) { - this.visitPercentages = visitPercentages.stream().mapToDouble(Double::doubleValue).toArray(); - return this; - } - - public Builder setIvfClusterSize(int ivfClusterSize) { - this.ivfClusterSize = ivfClusterSize; - return this; - } - - public Builder setOverSamplingFactor(int overSamplingFactor) { - this.overSamplingFactor = overSamplingFactor; - return this; - } - - public Builder setHnswM(int hnswM) { - this.hnswM = hnswM; - return 
this; - } - - public Builder setHnswEfConstruction(int hnswEfConstruction) { - this.hnswEfConstruction = hnswEfConstruction; - return this; - } - - public Builder setSearchThreads(int searchThreads) { - this.searchThreads = searchThreads; - return this; - } - - public Builder setNumSearchers(int numSearchers) { - this.numSearchers = numSearchers; - return this; - } - - public Builder setIndexThreads(int indexThreads) { - this.indexThreads = indexThreads; - return this; - } - - public Builder setReindex(boolean reindex) { - this.reindex = reindex; - return this; - } - - public Builder setForceMerge(boolean forceMerge) { - this.forceMerge = forceMerge; - return this; - } - - public Builder setVectorSpace(String vectorSpace) { - this.vectorSpace = VectorSimilarityFunction.valueOf(vectorSpace.toUpperCase(Locale.ROOT)); - return this; - } - - public Builder setRawVectorSize(int rawVectorSize) { - this.rawVectorSize = rawVectorSize; - return this; - } - - public Builder setQuantizeBits(int quantizeBits) { - this.quantizeBits = quantizeBits; - return this; - } - - public Builder setVectorEncoding(String vectorEncoding) { - this.vectorEncoding = VectorEncoding.valueOf(vectorEncoding.toUpperCase(Locale.ROOT)); - return this; - } - - public Builder setDimensions(int dimensions) { - this.dimensions = dimensions; - return this; - } - - public Builder setEarlyTermination(Boolean patience) { - this.earlyTermination = patience; - return this; - } - - public Builder setFilterSelectivity(float filterSelectivity) { - this.filterSelectivity = filterSelectivity; - return this; - } - - public Builder setSeed(long seed) { - this.seed = seed; - return this; - } - - public Builder setMergePolicy(String mergePolicy) { - this.mergePolicy = KnnIndexTester.MergePolicyType.valueOf(mergePolicy.toUpperCase(Locale.ROOT)); - return this; - } - - public Builder setWriterBufferMb(double writerBufferSizeInMb) { - this.writerBufferSizeInMb = writerBufferSizeInMb; - return this; - } - - public Builder 
setWriterMaxBufferedDocs(int writerMaxBufferedDocs) { - this.writerMaxBufferedDocs = writerMaxBufferedDocs; - return this; - } - - public CmdLineArgs build() { - if (docVectors == null) { - throw new IllegalArgumentException("Document vectors path must be provided"); - } - if (dimensions <= 0 && dimensions != -1) { - throw new IllegalArgumentException( - "dimensions must be a positive integer or -1 for when dimension is available in the vector file" - ); - } - return new CmdLineArgs( - docVectors, - queryVectors, - numDocs, - numQueries, - indexType, - numCandidates, - k, - visitPercentages, - ivfClusterSize, - overSamplingFactor, - hnswM, - hnswEfConstruction, - searchThreads, - numSearchers, - indexThreads, - reindex, - forceMerge, - filterSelectivity, - seed, - vectorSpace, - rawVectorSize, - quantizeBits, - vectorEncoding, - dimensions, - earlyTermination, - mergePolicy, - writerBufferSizeInMb, - writerMaxBufferedDocs - ); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index f178e1e61ba5d..8d25ab54d8ca1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -14,7 +14,7 @@ import java.nio.ByteOrder; import java.nio.ShortBuffer; -public class BFloat16 { +public final class BFloat16 { public static final int BYTES = Short.BYTES; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index cee32ac4ef470..a263275ed3342 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -88,10 +88,10 @@ public class 
ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final DirectIOCapableFlatVectorsFormat rawVectorFormat; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false, false); + this(vectorPerCluster, centroidsPerParentCluster, false); } - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) { + public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -116,7 +116,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; this.useDirectIO = useDirectIO; - this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; + this.rawVectorFormat = float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 2d5592314f198..7104c242ba1cb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -25,14 +25,12 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; import java.io.IOException; -import java.util.Map; /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 @@ -91,35 +89,16 @@ public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsForm public static final String NAME = "ES93BinaryQuantizedVectorsFormat"; - private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); - private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); - - private static final Map supportedFormats = Map.of( - float32VectorFormat.getName(), - float32VectorFormat, - bfloat16VectorFormat.getName(), - bfloat16VectorFormat - ); - private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( 
FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); - private final boolean useDirectIO; - private final DirectIOCapableFlatVectorsFormat rawFormat; - public ES93BinaryQuantizedVectorsFormat() { this(false, false); } - public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO, boolean useBFloat16) { - super(NAME); - this.useDirectIO = useDirectIO; - this.rawFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; + public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + super(NAME, useBFloat16, useDirectIO); } @Override @@ -127,21 +106,6 @@ protected FlatVectorsScorer flatVectorsScorer() { return scorer; } - @Override - protected boolean useDirectIOReads() { - return useDirectIO; - } - - @Override - protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() { - return rawFormat; - } - - @Override - protected Map supportedReadFlatVectorsFormats() { - return supportedFormats; - } - @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 526a4241ed89e..e70c3d16d1f26 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.codec.vectors.es93; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; @@ -34,28 +35,38 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo VERSION_CURRENT ); 
- public ES93GenericFlatVectorsFormat(String name) { - super(name); - } + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); - protected abstract DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat(); + private static final Map supportedFormats = Map.of( + float32VectorFormat.getName(), + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat + ); - protected abstract boolean useDirectIOReads(); + private final DirectIOCapableFlatVectorsFormat writeFormat; + private final boolean useDirectIO; - protected abstract Map supportedReadFlatVectorsFormats(); + public ES93GenericFlatVectorsFormat(String name, boolean useBFloat16, boolean useDirectIO) { + super(name); + writeFormat = useBFloat16 ? 
bfloat16VectorFormat : float32VectorFormat; + this.useDirectIO = useDirectIO; + } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - var flatFormat = writeFlatVectorsFormat(); - boolean directIO = useDirectIOReads(); - return new ES93GenericFlatVectorsWriter(META, flatFormat.getName(), directIO, state, flatFormat.fieldsWriter(state)); + return new ES93GenericFlatVectorsWriter(META, writeFormat.getName(), useDirectIO, state, writeFormat.fieldsWriter(state)); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - var readFormats = supportedReadFlatVectorsFormats(); return new ES93GenericFlatVectorsReader(META, state, (f, dio) -> { - var format = readFormats.get(f); + var format = supportedFormats.get(f); if (format == null) return null; return format.fieldsReader(state, dio); }); @@ -63,13 +74,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException @Override public String toString() { - return getName() - + "(name=" - + getName() - + ", writeFlatVectorFormat=" - + writeFlatVectorsFormat() - + ", readFlatVectorsFormats=" - + supportedReadFlatVectorsFormats().values() - + ")"; + return getName() + "(name=" + getName() + ", writeFlatVectorFormat=" + writeFormat + ")"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index 579c42edc6288..c9cbe015c063e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -53,7 +53,7 @@ public ES93HnswBinaryQuantizedVectorsFormat() { */ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO, boolean useBFloat16) { super(NAME, 
maxConn, beamWidth); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO, useBFloat16); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } /** @@ -76,7 +76,7 @@ public ES93HnswBinaryQuantizedVectorsFormat( ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO, useBFloat16); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java index 2038cb5232666..42f02d2d21366 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -21,7 +21,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; -import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; @@ -36,7 +35,7 @@ import java.io.IOException; -abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements HasIndexSlice { +abstract class OffHeapBFloat16VectorValues extends FloatVectorValues { protected final int dimension; protected final int size; @@ -62,7 +61,7 @@ abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements this.byteSize = byteSize; this.similarityFunction = similarityFunction; this.flatVectorsScorer = flatVectorsScorer; - bfloatBytes = new byte[dimension * 2]; + bfloatBytes = new byte[dimension * BFloat16.BYTES]; value = new float[dimension]; } @@ -76,11 +75,6 @@ public int size() 
{ return size; } - @Override - public IndexInput getSlice() { - return slice; - } - @Override public float[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c40986be0ffbf..33d99c5628732 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2166,8 +2166,7 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { return new ES920DiskBBQVectorsFormat( clusterSize, ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - onDiskRescore, - false + onDiskRescore ); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java deleted file mode 100644 index 48b44df9732fd..0000000000000 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.script.field.vectors; - -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.index.codec.vectors.BFloat16; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.ShortBuffer; -import java.util.Iterator; - -public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField { - - private final BinaryDocValues input; - private final BinaryDocValues magnitudes; - private boolean decoded; - private final int dims; - private BytesRef value; - private BytesRef magnitudesValue; - private BFloat16VectorIterator vectorValues; - private int numVectors; - private float[] buffer; - - public BFloat16RankVectorsDocValuesField( - BinaryDocValues input, - BinaryDocValues magnitudes, - String name, - ElementType elementType, - int dims - ) { - super(name, elementType); - this.input = input; - this.magnitudes = magnitudes; - this.dims = dims; - this.buffer = new float[dims]; - } - - @Override - public void setNextDocId(int docId) throws IOException { - decoded = false; - if (input.advanceExact(docId)) { - boolean magnitudesFound = magnitudes.advanceExact(docId); - assert magnitudesFound; - - value = input.binaryValue(); - assert value.length % (BFloat16.BYTES * dims) == 0; - numVectors = value.length / (BFloat16.BYTES * dims); - magnitudesValue = magnitudes.binaryValue(); - assert magnitudesValue.length == (Float.BYTES * numVectors); - } else { - value = null; - magnitudesValue = null; - numVectors = 0; - } - } - - @Override - public RankVectorsScriptDocValues toScriptDocValues() { - return new RankVectorsScriptDocValues(this, dims); - } - - @Override - public boolean isEmpty() { - return value == null; - } - - @Override - public RankVectors get() { - if (isEmpty()) { - 
return RankVectors.EMPTY; - } - decodeVectorIfNecessary(); - return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); - } - - @Override - public RankVectors get(RankVectors defaultValue) { - if (isEmpty()) { - return defaultValue; - } - decodeVectorIfNecessary(); - return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); - } - - @Override - public RankVectors getInternal() { - return get(null); - } - - @Override - public int size() { - return value == null ? 0 : value.length / (BFloat16.BYTES * dims); - } - - private void decodeVectorIfNecessary() { - if (decoded == false && value != null) { - vectorValues = new BFloat16VectorIterator(value, buffer, numVectors); - decoded = true; - } - } - - public static class BFloat16VectorIterator implements VectorIterator { - private final float[] buffer; - private final ShortBuffer vectorValues; - private final BytesRef vectorValueBytesRef; - private final int size; - private int idx = 0; - - public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) { - assert vectorValues.length == (buffer.length * BFloat16.BYTES * size); - this.vectorValueBytesRef = vectorValues; - this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) - .order(ByteOrder.LITTLE_ENDIAN) - .asShortBuffer(); - this.size = size; - this.buffer = buffer; - } - - @Override - public boolean hasNext() { - return idx < size; - } - - @Override - public float[] next() { - if (hasNext() == false) { - throw new IllegalArgumentException("No more elements in the iterator"); - } - BFloat16.bFloat16ToFloat(vectorValues, buffer); - idx++; - return buffer; - } - - @Override - public Iterator copy() { - return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size); - } - - @Override - public void reset() { - idx = 0; - vectorValues.rewind(); - } - } -} diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java deleted file mode 100644 index 38548deff5b45..0000000000000 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.index.codec.vectors.diskbbq; - -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static org.hamcrest.Matchers.closeTo; - -public class ES920DiskBBQBFloat16VectorsFormatTests extends ES920DiskBBQVectorsFormatTests { - @Override - boolean useBFloat16() { - return true; - } - - @Override - public void testEmptyByteVectorData() throws Exception { - // no bytes - } - - @Override - public void testMergingWithDifferentByteKnnFields() throws Exception { - // no bytes - } - - @Override - public void testByteVectorScorerIteration() throws Exception { - // no bytes - } - - @Override - public void testSortedIndexBytes() throws Exception { - // no bytes - } - - @Override - public void testMismatchedFields() throws Exception { - // no bytes - } - - @Override - public void testRandomBytes() throws Exception { - // no bytes - } - - @Override - public void testWriterRamEstimate() throws Exception { - // estimate is different due to bfloat16 - } - - @Override - public void testRandom() throws Exception { - AssertionError err = 
expectThrows(AssertionError.class, super::testRandom); - assertFloatsWithinBounds(err); - } - - @Override - public void testSparseVectors() throws Exception { - AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); - assertFloatsWithinBounds(err); - } - - @Override - public void testVectorValuesReportCorrectDocs() throws Exception { - AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); - assertFloatsWithinBounds(err); - } - - @Override - public void testRandomWithUpdatesAndGraph() throws Exception { - AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); - assertFloatsWithinBounds(err); - } - - private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); - - private static void assertFloatsWithinBounds(AssertionError error) { - Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); - if (m.matches() == false) { - throw error; // nothing to do with us, just rethrow - } - - // numbers just need to be in the same vicinity - double expected = Double.parseDouble(m.group(1)); - double actual = Double.parseDouble(m.group(2)); - double allowedError = expected * 0.01; // within 1% - assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); - } -} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 1535f71e3ba54..29e6c59d995be 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -63,10 +63,6 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase private KnnVectorsFormat format; - boolean useBFloat16() { - 
return false; - } - @Before @Override public void setUp() throws Exception { @@ -74,16 +70,14 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), - random().nextBoolean(), - useBFloat16() + random().nextBoolean() ); } else { // run with low numbers to force many clusters with parents format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), - random().nextBoolean(), - useBFloat16() + random().nextBoolean() ); } super.setUp(); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 20689e773ee79..1fd99df020ebf 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -84,7 +84,7 @@ boolean useBFloat16() { @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(random().nextBoolean(), useBFloat16()); + format = new ES93BinaryQuantizedVectorsFormat(useBFloat16(), random().nextBoolean()); super.setUp(); } From e910cd5dd5fff95e08727a6c43ee9555f9865452 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 16:25:12 +0100 Subject: [PATCH 07/46] Propagate files across --- .../elasticsearch/test/knn/CmdLineArgs.java | 404 ++++++++++++++++++ .../diskbbq/ES920DiskBBQVectorsFormat.java | 6 +- ...S920DiskBBQBFloat16VectorsFormatTests.java | 96 +++++ .../ES920DiskBBQVectorsFormatTests.java | 10 +- 4 files changed, 511 insertions(+), 5 deletions(-) create mode 
100644 qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java new file mode 100644 index 0000000000000..27272418b29f0 --- /dev/null +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -0,0 +1,404 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test.knn; + +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.PathUtils; +import org.elasticsearch.monitor.jvm.JvmInfo; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +/** + * Command line arguments for the KNN index tester. + * This class encapsulates all the parameters required to run the KNN index tests. 
+ */ +record CmdLineArgs( + List docVectors, + Path queryVectors, + int numDocs, + int numQueries, + KnnIndexTester.IndexType indexType, + int numCandidates, + int k, + double[] visitPercentages, + int ivfClusterSize, + int overSamplingFactor, + int hnswM, + int hnswEfConstruction, + int searchThreads, + int numSearchers, + int indexThreads, + boolean reindex, + boolean forceMerge, + float filterSelectivity, + long seed, + VectorSimilarityFunction vectorSpace, + int rawVectorSize, + int quantizeBits, + VectorEncoding vectorEncoding, + int dimensions, + boolean earlyTermination, + KnnIndexTester.MergePolicyType mergePolicy, + double writerBufferSizeInMb, + int writerMaxBufferedDocs +) implements ToXContentObject { + + static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); + static final ParseField QUERY_VECTORS_FIELD = new ParseField("query_vectors"); + static final ParseField NUM_DOCS_FIELD = new ParseField("num_docs"); + static final ParseField NUM_QUERIES_FIELD = new ParseField("num_queries"); + static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type"); + static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates"); + static final ParseField K_FIELD = new ParseField("k"); + // static final ParseField N_PROBE_FIELD = new ParseField("n_probe"); + static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); + static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size"); + static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor"); + static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m"); + static final ParseField HNSW_EF_CONSTRUCTION_FIELD = new ParseField("hnsw_ef_construction"); + static final ParseField NUM_SEARCHERS_FIELD = new ParseField("num_searchers"); + static final ParseField SEARCH_THREADS_FIELD = new ParseField("search_threads"); + static final ParseField INDEX_THREADS_FIELD = new ParseField("index_threads"); + 
static final ParseField REINDEX_FIELD = new ParseField("reindex"); + static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); + static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); + static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); + static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size"); + static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); + static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); + static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity"); + static final ParseField SEED_FIELD = new ParseField("seed"); + static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy"); + static final ParseField WRITER_BUFFER_MB_FIELD = new ParseField("writer_buffer_mb"); + static final ParseField WRITER_BUFFER_DOCS_FIELD = new ParseField("writer_buffer_docs"); + + /** By default, in ES the default writer buffer size is 10% of the heap space + * (see {@code IndexingMemoryController.INDEX_BUFFER_SIZE_SETTING}). + * We configure the Java heap size for this tool in {@code build.gradle}; currently we default to 16GB, so in that case + * the buffer size would be 1.6GB. 
+ */ + static final double DEFAULT_WRITER_BUFFER_MB = (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() / (1024.0 * 1024.0)) * 0.1; + + static CmdLineArgs fromXContent(XContentParser parser) throws IOException { + Builder builder = PARSER.apply(parser, null); + return builder.build(); + } + + static final ObjectParser PARSER = new ObjectParser<>("cmd_line_args", false, Builder::new); + + static { + PARSER.declareStringArray(Builder::setDocVectors, DOC_VECTORS_FIELD); + PARSER.declareString(Builder::setQueryVectors, QUERY_VECTORS_FIELD); + PARSER.declareInt(Builder::setNumDocs, NUM_DOCS_FIELD); + PARSER.declareInt(Builder::setNumQueries, NUM_QUERIES_FIELD); + PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD); + PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD); + PARSER.declareInt(Builder::setK, K_FIELD); + // PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD); + PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD); + PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD); + PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD); + PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD); + PARSER.declareInt(Builder::setHnswEfConstruction, HNSW_EF_CONSTRUCTION_FIELD); + PARSER.declareInt(Builder::setSearchThreads, SEARCH_THREADS_FIELD); + PARSER.declareInt(Builder::setNumSearchers, NUM_SEARCHERS_FIELD); + PARSER.declareInt(Builder::setIndexThreads, INDEX_THREADS_FIELD); + PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); + PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); + PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); + PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD); + PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); + PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); + PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); + 
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD); + PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD); + PARSER.declareLong(Builder::setSeed, SEED_FIELD); + PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD); + PARSER.declareDouble(Builder::setWriterBufferMb, WRITER_BUFFER_MB_FIELD); + PARSER.declareInt(Builder::setWriterMaxBufferedDocs, WRITER_BUFFER_DOCS_FIELD); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (docVectors != null) { + List docVectorsStrings = docVectors.stream().map(Path::toString).toList(); + builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectorsStrings); + } + if (queryVectors != null) { + builder.field(QUERY_VECTORS_FIELD.getPreferredName(), queryVectors.toString()); + } + builder.field(NUM_DOCS_FIELD.getPreferredName(), numDocs); + builder.field(NUM_QUERIES_FIELD.getPreferredName(), numQueries); + builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT)); + builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates); + builder.field(K_FIELD.getPreferredName(), k); + // builder.field(N_PROBE_FIELD.getPreferredName(), nProbes); + builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages); + builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize); + builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor); + builder.field(HNSW_M_FIELD.getPreferredName(), hnswM); + builder.field(HNSW_EF_CONSTRUCTION_FIELD.getPreferredName(), hnswEfConstruction); + builder.field(SEARCH_THREADS_FIELD.getPreferredName(), searchThreads); + builder.field(NUM_SEARCHERS_FIELD.getPreferredName(), numSearchers); + builder.field(INDEX_THREADS_FIELD.getPreferredName(), indexThreads); + builder.field(REINDEX_FIELD.getPreferredName(), reindex); + builder.field(FORCE_MERGE_FIELD.getPreferredName(), 
forceMerge); + builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); + builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize); + builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); + builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination); + builder.field(FILTER_SELECTIVITY_FIELD.getPreferredName(), filterSelectivity); + builder.field(SEED_FIELD.getPreferredName(), seed); + builder.field(WRITER_BUFFER_MB_FIELD.getPreferredName(), writerBufferSizeInMb); + builder.field(WRITER_BUFFER_DOCS_FIELD.getPreferredName(), writerMaxBufferedDocs); + return builder.endObject(); + } + + @Override + public String toString() { + return Strings.toString(this, false, false); + } + + static class Builder { + private List docVectors; + private Path queryVectors; + private int numDocs = 1000; + private int numQueries = 100; + private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW; + private int numCandidates = 1000; + private int k = 10; + private double[] visitPercentages = new double[] { 1.0 }; + private int ivfClusterSize = 1000; + private int overSamplingFactor = 1; + private int hnswM = 16; + private int hnswEfConstruction = 200; + private int searchThreads = 1; + private int numSearchers = 1; + private int indexThreads = 1; + private boolean reindex = false; + private boolean forceMerge = false; + private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; + private int rawVectorSize = 32; + private int quantizeBits = 8; + private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; + private int dimensions; + private boolean earlyTermination; + private float filterSelectivity = 1f; + private long seed = 1751900822751L; + private KnnIndexTester.MergePolicyType 
mergePolicy = null; + private double writerBufferSizeInMb = DEFAULT_WRITER_BUFFER_MB; + + /** + * Elasticsearch does not set this explicitly, and in Lucene this setting is + * disabled by default (writer flushes by RAM usage). + */ + private int writerMaxBufferedDocs = IndexWriterConfig.DISABLE_AUTO_FLUSH; + + public Builder setDocVectors(List docVectors) { + if (docVectors == null || docVectors.isEmpty()) { + throw new IllegalArgumentException("Document vectors path must be provided"); + } + // Convert list of strings to list of Paths + this.docVectors = docVectors.stream().map(PathUtils::get).toList(); + return this; + } + + public Builder setQueryVectors(String queryVectors) { + this.queryVectors = PathUtils.get(queryVectors); + return this; + } + + public Builder setNumDocs(int numDocs) { + this.numDocs = numDocs; + return this; + } + + public Builder setNumQueries(int numQueries) { + this.numQueries = numQueries; + return this; + } + + public Builder setIndexType(String indexType) { + this.indexType = KnnIndexTester.IndexType.valueOf(indexType.toUpperCase(Locale.ROOT)); + return this; + } + + public Builder setNumCandidates(int numCandidates) { + this.numCandidates = numCandidates; + return this; + } + + public Builder setK(int k) { + this.k = k; + return this; + } + + public Builder setVisitPercentages(List visitPercentages) { + this.visitPercentages = visitPercentages.stream().mapToDouble(Double::doubleValue).toArray(); + return this; + } + + public Builder setIvfClusterSize(int ivfClusterSize) { + this.ivfClusterSize = ivfClusterSize; + return this; + } + + public Builder setOverSamplingFactor(int overSamplingFactor) { + this.overSamplingFactor = overSamplingFactor; + return this; + } + + public Builder setHnswM(int hnswM) { + this.hnswM = hnswM; + return this; + } + + public Builder setHnswEfConstruction(int hnswEfConstruction) { + this.hnswEfConstruction = hnswEfConstruction; + return this; + } + + public Builder setSearchThreads(int searchThreads) { + 
this.searchThreads = searchThreads; + return this; + } + + public Builder setNumSearchers(int numSearchers) { + this.numSearchers = numSearchers; + return this; + } + + public Builder setIndexThreads(int indexThreads) { + this.indexThreads = indexThreads; + return this; + } + + public Builder setReindex(boolean reindex) { + this.reindex = reindex; + return this; + } + + public Builder setForceMerge(boolean forceMerge) { + this.forceMerge = forceMerge; + return this; + } + + public Builder setVectorSpace(String vectorSpace) { + this.vectorSpace = VectorSimilarityFunction.valueOf(vectorSpace.toUpperCase(Locale.ROOT)); + return this; + } + + public Builder setRawVectorSize(int rawVectorSize) { + this.rawVectorSize = rawVectorSize; + return this; + } + + public Builder setQuantizeBits(int quantizeBits) { + this.quantizeBits = quantizeBits; + return this; + } + + public Builder setVectorEncoding(String vectorEncoding) { + this.vectorEncoding = VectorEncoding.valueOf(vectorEncoding.toUpperCase(Locale.ROOT)); + return this; + } + + public Builder setDimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + + public Builder setEarlyTermination(Boolean patience) { + this.earlyTermination = patience; + return this; + } + + public Builder setFilterSelectivity(float filterSelectivity) { + this.filterSelectivity = filterSelectivity; + return this; + } + + public Builder setSeed(long seed) { + this.seed = seed; + return this; + } + + public Builder setMergePolicy(String mergePolicy) { + this.mergePolicy = KnnIndexTester.MergePolicyType.valueOf(mergePolicy.toUpperCase(Locale.ROOT)); + return this; + } + + public Builder setWriterBufferMb(double writerBufferSizeInMb) { + this.writerBufferSizeInMb = writerBufferSizeInMb; + return this; + } + + public Builder setWriterMaxBufferedDocs(int writerMaxBufferedDocs) { + this.writerMaxBufferedDocs = writerMaxBufferedDocs; + return this; + } + + public CmdLineArgs build() { + if (docVectors == null) { + throw new 
IllegalArgumentException("Document vectors path must be provided"); + } + if (dimensions <= 0 && dimensions != -1) { + throw new IllegalArgumentException( + "dimensions must be a positive integer or -1 for when dimension is available in the vector file" + ); + } + return new CmdLineArgs( + docVectors, + queryVectors, + numDocs, + numQueries, + indexType, + numCandidates, + k, + visitPercentages, + ivfClusterSize, + overSamplingFactor, + hnswM, + hnswEfConstruction, + searchThreads, + numSearchers, + indexThreads, + reindex, + forceMerge, + filterSelectivity, + seed, + vectorSpace, + rawVectorSize, + quantizeBits, + vectorEncoding, + dimensions, + earlyTermination, + mergePolicy, + writerBufferSizeInMb, + writerMaxBufferedDocs + ); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index a263275ed3342..cee32ac4ef470 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -88,10 +88,10 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final DirectIOCapableFlatVectorsFormat rawVectorFormat; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false); + this(vectorPerCluster, centroidsPerParentCluster, false, false); } - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) { + public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -116,7 +116,7 @@ public 
ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; this.useDirectIO = useDirectIO; - this.rawVectorFormat = float32VectorFormat; + this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. */ diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..38548deff5b45 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.diskbbq; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES920DiskBBQBFloat16VectorsFormatTests extends ES920DiskBBQVectorsFormatTests { + @Override + boolean useBFloat16() { + return true; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); + + private static void assertFloatsWithinBounds(AssertionError error) { + 
Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); + if (m.matches() == false) { + throw error; // nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = expected * 0.01; // within 1% + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 29e6c59d995be..1535f71e3ba54 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -63,6 +63,10 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase private KnnVectorsFormat format; + boolean useBFloat16() { + return false; + } + @Before @Override public void setUp() throws Exception { @@ -70,14 +74,16 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), - random().nextBoolean() + random().nextBoolean(), + useBFloat16() ); } else { // run with low numbers to force many clusters with parents format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), - random().nextBoolean() + random().nextBoolean(), + useBFloat16() ); } super.setUp(); From 5ef4cf1932b858e99beef7607da12797b2842a96 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 16:20:21 +0100 Subject: [PATCH 08/46] PR comments --- 
.../elasticsearch/test/knn/CmdLineArgs.java | 11 -- .../index/codec/vectors/BFloat16.java | 2 +- .../diskbbq/ES920DiskBBQVectorsFormat.java | 14 +- .../ES93BinaryQuantizedVectorsFormat.java | 40 +---- .../es93/ES93GenericFlatVectorsFormat.java | 42 ++--- .../ES93HnswBinaryQuantizedVectorsFormat.java | 4 +- .../es93/OffHeapBFloat16VectorValues.java | 10 +- .../vectors/DenseVectorFieldMapper.java | 3 +- .../BFloat16RankVectorsDocValuesField.java | 157 ------------------ ...S920DiskBBQBFloat16VectorsFormatTests.java | 96 ----------- .../ES920DiskBBQVectorsFormatTests.java | 10 +- ...ES93BinaryQuantizedVectorsFormatTests.java | 2 +- 12 files changed, 38 insertions(+), 353 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java delete mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 27272418b29f0..773ad3c8da682 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -51,7 +51,6 @@ record CmdLineArgs( float filterSelectivity, long seed, VectorSimilarityFunction vectorSpace, - int rawVectorSize, int quantizeBits, VectorEncoding vectorEncoding, int dimensions, @@ -81,7 +80,6 @@ record CmdLineArgs( static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); - static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size"); static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); static final 
ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); @@ -125,7 +123,6 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); - PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD); PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); @@ -164,7 +161,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(REINDEX_FIELD.getPreferredName(), reindex); builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); - builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize); builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); @@ -200,7 +196,6 @@ static class Builder { private boolean reindex = false; private boolean forceMerge = false; private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; - private int rawVectorSize = 32; private int quantizeBits = 8; private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; private int dimensions; @@ -310,11 +305,6 @@ public Builder setVectorSpace(String vectorSpace) { return this; } - public Builder setRawVectorSize(int rawVectorSize) { - this.rawVectorSize = rawVectorSize; - return this; - } - public Builder setQuantizeBits(int quantizeBits) { this.quantizeBits = quantizeBits; return this; @@ -390,7 +380,6 @@ public CmdLineArgs build() { filterSelectivity, seed, 
vectorSpace, - rawVectorSize, quantizeBits, vectorEncoding, dimensions, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index f178e1e61ba5d..8d25ab54d8ca1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -14,7 +14,7 @@ import java.nio.ByteOrder; import java.nio.ShortBuffer; -public class BFloat16 { +public final class BFloat16 { public static final int BYTES = Short.BYTES; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index cee32ac4ef470..99bc9a9d7bdb2 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -18,7 +18,6 @@ import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat; import java.io.IOException; import java.util.Map; @@ -62,14 +61,9 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); - private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); private static final Map supportedFormats = Map.of( float32VectorFormat.getName(), - 
float32VectorFormat, - bfloat16VectorFormat.getName(), - bfloat16VectorFormat + float32VectorFormat ); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. @@ -88,10 +82,10 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final DirectIOCapableFlatVectorsFormat rawVectorFormat; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false, false); + this(vectorPerCluster, centroidsPerParentCluster, false); } - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) { + public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -116,7 +110,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; this.useDirectIO = useDirectIO; - this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; + this.rawVectorFormat = float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 2d5592314f198..7104c242ba1cb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -25,14 +25,12 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; import java.io.IOException; -import java.util.Map; /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 @@ -91,35 +89,16 @@ public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsForm public static final String NAME = "ES93BinaryQuantizedVectorsFormat"; - private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); - private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); - - private static final Map supportedFormats = Map.of( - float32VectorFormat.getName(), - float32VectorFormat, - bfloat16VectorFormat.getName(), - bfloat16VectorFormat - ); - private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( 
FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); - private final boolean useDirectIO; - private final DirectIOCapableFlatVectorsFormat rawFormat; - public ES93BinaryQuantizedVectorsFormat() { this(false, false); } - public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO, boolean useBFloat16) { - super(NAME); - this.useDirectIO = useDirectIO; - this.rawFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; + public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + super(NAME, useBFloat16, useDirectIO); } @Override @@ -127,21 +106,6 @@ protected FlatVectorsScorer flatVectorsScorer() { return scorer; } - @Override - protected boolean useDirectIOReads() { - return useDirectIO; - } - - @Override - protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() { - return rawFormat; - } - - @Override - protected Map supportedReadFlatVectorsFormats() { - return supportedFormats; - } - @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 526a4241ed89e..e70c3d16d1f26 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.codec.vectors.es93; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; @@ -34,28 +35,38 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo VERSION_CURRENT ); 
- public ES93GenericFlatVectorsFormat(String name) { - super(name); - } + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); - protected abstract DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat(); + private static final Map supportedFormats = Map.of( + float32VectorFormat.getName(), + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat + ); - protected abstract boolean useDirectIOReads(); + private final DirectIOCapableFlatVectorsFormat writeFormat; + private final boolean useDirectIO; - protected abstract Map supportedReadFlatVectorsFormats(); + public ES93GenericFlatVectorsFormat(String name, boolean useBFloat16, boolean useDirectIO) { + super(name); + writeFormat = useBFloat16 ? 
bfloat16VectorFormat : float32VectorFormat; + this.useDirectIO = useDirectIO; + } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - var flatFormat = writeFlatVectorsFormat(); - boolean directIO = useDirectIOReads(); - return new ES93GenericFlatVectorsWriter(META, flatFormat.getName(), directIO, state, flatFormat.fieldsWriter(state)); + return new ES93GenericFlatVectorsWriter(META, writeFormat.getName(), useDirectIO, state, writeFormat.fieldsWriter(state)); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - var readFormats = supportedReadFlatVectorsFormats(); return new ES93GenericFlatVectorsReader(META, state, (f, dio) -> { - var format = readFormats.get(f); + var format = supportedFormats.get(f); if (format == null) return null; return format.fieldsReader(state, dio); }); @@ -63,13 +74,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException @Override public String toString() { - return getName() - + "(name=" - + getName() - + ", writeFlatVectorFormat=" - + writeFlatVectorsFormat() - + ", readFlatVectorsFormats=" - + supportedReadFlatVectorsFormats().values() - + ")"; + return getName() + "(name=" + getName() + ", writeFlatVectorFormat=" + writeFormat + ")"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index 579c42edc6288..c9cbe015c063e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -53,7 +53,7 @@ public ES93HnswBinaryQuantizedVectorsFormat() { */ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO, boolean useBFloat16) { super(NAME, 
maxConn, beamWidth); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO, useBFloat16); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } /** @@ -76,7 +76,7 @@ public ES93HnswBinaryQuantizedVectorsFormat( ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO, useBFloat16); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java index 2038cb5232666..42f02d2d21366 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -21,7 +21,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; -import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; @@ -36,7 +35,7 @@ import java.io.IOException; -abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements HasIndexSlice { +abstract class OffHeapBFloat16VectorValues extends FloatVectorValues { protected final int dimension; protected final int size; @@ -62,7 +61,7 @@ abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements this.byteSize = byteSize; this.similarityFunction = similarityFunction; this.flatVectorsScorer = flatVectorsScorer; - bfloatBytes = new byte[dimension * 2]; + bfloatBytes = new byte[dimension * BFloat16.BYTES]; value = new float[dimension]; } @@ -76,11 +75,6 @@ public int size() 
{ return size; } - @Override - public IndexInput getSlice() { - return slice; - } - @Override public float[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c40986be0ffbf..33d99c5628732 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2166,8 +2166,7 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { return new ES920DiskBBQVectorsFormat( clusterSize, ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - onDiskRescore, - false + onDiskRescore ); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java deleted file mode 100644 index 48b44df9732fd..0000000000000 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.script.field.vectors; - -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.index.codec.vectors.BFloat16; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.ShortBuffer; -import java.util.Iterator; - -public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField { - - private final BinaryDocValues input; - private final BinaryDocValues magnitudes; - private boolean decoded; - private final int dims; - private BytesRef value; - private BytesRef magnitudesValue; - private BFloat16VectorIterator vectorValues; - private int numVectors; - private float[] buffer; - - public BFloat16RankVectorsDocValuesField( - BinaryDocValues input, - BinaryDocValues magnitudes, - String name, - ElementType elementType, - int dims - ) { - super(name, elementType); - this.input = input; - this.magnitudes = magnitudes; - this.dims = dims; - this.buffer = new float[dims]; - } - - @Override - public void setNextDocId(int docId) throws IOException { - decoded = false; - if (input.advanceExact(docId)) { - boolean magnitudesFound = magnitudes.advanceExact(docId); - assert magnitudesFound; - - value = input.binaryValue(); - assert value.length % (BFloat16.BYTES * dims) == 0; - numVectors = value.length / (BFloat16.BYTES * dims); - magnitudesValue = magnitudes.binaryValue(); - assert magnitudesValue.length == (Float.BYTES * numVectors); - } else { - value = null; - magnitudesValue = null; - numVectors = 0; - } - } - - @Override - public RankVectorsScriptDocValues toScriptDocValues() { - return new RankVectorsScriptDocValues(this, dims); - } - - @Override - public boolean isEmpty() { - return value == null; - } - - @Override - public RankVectors get() { - if (isEmpty()) { - 
return RankVectors.EMPTY; - } - decodeVectorIfNecessary(); - return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); - } - - @Override - public RankVectors get(RankVectors defaultValue) { - if (isEmpty()) { - return defaultValue; - } - decodeVectorIfNecessary(); - return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); - } - - @Override - public RankVectors getInternal() { - return get(null); - } - - @Override - public int size() { - return value == null ? 0 : value.length / (BFloat16.BYTES * dims); - } - - private void decodeVectorIfNecessary() { - if (decoded == false && value != null) { - vectorValues = new BFloat16VectorIterator(value, buffer, numVectors); - decoded = true; - } - } - - public static class BFloat16VectorIterator implements VectorIterator { - private final float[] buffer; - private final ShortBuffer vectorValues; - private final BytesRef vectorValueBytesRef; - private final int size; - private int idx = 0; - - public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) { - assert vectorValues.length == (buffer.length * BFloat16.BYTES * size); - this.vectorValueBytesRef = vectorValues; - this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) - .order(ByteOrder.LITTLE_ENDIAN) - .asShortBuffer(); - this.size = size; - this.buffer = buffer; - } - - @Override - public boolean hasNext() { - return idx < size; - } - - @Override - public float[] next() { - if (hasNext() == false) { - throw new IllegalArgumentException("No more elements in the iterator"); - } - BFloat16.bFloat16ToFloat(vectorValues, buffer); - idx++; - return buffer; - } - - @Override - public Iterator copy() { - return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size); - } - - @Override - public void reset() { - idx = 0; - vectorValues.rewind(); - } - } -} diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java deleted file mode 100644 index 38548deff5b45..0000000000000 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.index.codec.vectors.diskbbq; - -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static org.hamcrest.Matchers.closeTo; - -public class ES920DiskBBQBFloat16VectorsFormatTests extends ES920DiskBBQVectorsFormatTests { - @Override - boolean useBFloat16() { - return true; - } - - @Override - public void testEmptyByteVectorData() throws Exception { - // no bytes - } - - @Override - public void testMergingWithDifferentByteKnnFields() throws Exception { - // no bytes - } - - @Override - public void testByteVectorScorerIteration() throws Exception { - // no bytes - } - - @Override - public void testSortedIndexBytes() throws Exception { - // no bytes - } - - @Override - public void testMismatchedFields() throws Exception { - // no bytes - } - - @Override - public void testRandomBytes() throws Exception { - // no bytes - } - - @Override - public void testWriterRamEstimate() throws Exception { - // estimate is different due to bfloat16 - } - - @Override - public void testRandom() throws Exception { - AssertionError err = 
expectThrows(AssertionError.class, super::testRandom); - assertFloatsWithinBounds(err); - } - - @Override - public void testSparseVectors() throws Exception { - AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); - assertFloatsWithinBounds(err); - } - - @Override - public void testVectorValuesReportCorrectDocs() throws Exception { - AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); - assertFloatsWithinBounds(err); - } - - @Override - public void testRandomWithUpdatesAndGraph() throws Exception { - AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); - assertFloatsWithinBounds(err); - } - - private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); - - private static void assertFloatsWithinBounds(AssertionError error) { - Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); - if (m.matches() == false) { - throw error; // nothing to do with us, just rethrow - } - - // numbers just need to be in the same vicinity - double expected = Double.parseDouble(m.group(1)); - double actual = Double.parseDouble(m.group(2)); - double allowedError = expected * 0.01; // within 1% - assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); - } -} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 1535f71e3ba54..29e6c59d995be 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -63,10 +63,6 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase private KnnVectorsFormat format; - boolean useBFloat16() { - 
return false; - } - @Before @Override public void setUp() throws Exception { @@ -74,16 +70,14 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), - random().nextBoolean(), - useBFloat16() + random().nextBoolean() ); } else { // run with low numbers to force many clusters with parents format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), - random().nextBoolean(), - useBFloat16() + random().nextBoolean() ); } super.setUp(); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 20689e773ee79..1fd99df020ebf 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -84,7 +84,7 @@ boolean useBFloat16() { @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(random().nextBoolean(), useBFloat16()); + format = new ES93BinaryQuantizedVectorsFormat(useBFloat16(), random().nextBoolean()); super.setUp(); } From 69f52cbc28a85c20a1448fc43ace2422b495a7f0 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 17:02:37 +0100 Subject: [PATCH 09/46] Turn the generic format into a proper format --- server/src/main/java/module-info.java | 1 + .../ES93BinaryQuantizedVectorsFormat.java | 17 ++++++++--- .../es93/ES93GenericFlatVectorsFormat.java | 30 ++++++++++++------- .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + ...ES93BinaryQuantizedVectorsFormatTests.java | 
16 +++++----- ...HnswBinaryQuantizedVectorsFormatTests.java | 18 +++++------ 6 files changed, 51 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 2987b3849e663..0ac1ec1fbb612 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -464,6 +464,7 @@ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat, + org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 7104c242ba1cb..2535784bd1004 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -25,6 +25,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; @@ -85,7 +86,7 @@ *
  • The sparse vector information, if required, mapping vector ordinal to doc ID * */ -public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsFormat { +public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat { public static final String NAME = "ES93BinaryQuantizedVectorsFormat"; @@ -93,12 +94,15 @@ public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsForm FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); + private final ES93GenericFlatVectorsFormat rawFormat; + public ES93BinaryQuantizedVectorsFormat() { this(false, false); } public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) { - super(NAME, useBFloat16, useDirectIO); + super(NAME); + rawFormat = new ES93GenericFlatVectorsFormat(useBFloat16, useDirectIO); } @Override @@ -108,11 +112,16 @@ protected FlatVectorsScorer flatVectorsScorer() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state); + return new ES818BinaryQuantizedVectorsWriter(scorer, rawFormat.fieldsWriter(state), state); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new ES818BinaryQuantizedVectorsReader(state, super.fieldsReader(state), scorer); + return new ES818BinaryQuantizedVectorsReader(state, rawFormat.fieldsReader(state), scorer); + } + + @Override + public String toString() { + return getName() + "(name=" + getName() + ", rawVectorFormat=" + rawFormat + ", scorer=" + scorer + ")"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index e70c3d16d1f26..e2026e24506e7 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -11,6 +11,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -20,8 +21,9 @@ import java.io.IOException; import java.util.Map; -public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { +public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { + static final String NAME = "ES93GenericFlatVectorsFormat"; static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta"; @@ -35,12 +37,11 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo VERSION_CURRENT ); - private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); - private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( - FlatVectorScorerUtil.getLucene99FlatVectorsScorer() - ); + private static final FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); + + private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(scorer); + // TODO: a separate scorer for bfloat16 + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(scorer); private static final Map supportedFormats = Map.of( float32VectorFormat.getName(), @@ -52,12 +53,21 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo private final DirectIOCapableFlatVectorsFormat writeFormat; 
private final boolean useDirectIO; - public ES93GenericFlatVectorsFormat(String name, boolean useBFloat16, boolean useDirectIO) { - super(name); + public ES93GenericFlatVectorsFormat() { + this(false, false); + } + + public ES93GenericFlatVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + super(NAME); writeFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; this.useDirectIO = useDirectIO; } + @Override + protected FlatVectorsScorer flatVectorsScorer() { + return scorer; + } + @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new ES93GenericFlatVectorsWriter(META, writeFormat.getName(), useDirectIO, state, writeFormat.fieldsWriter(state)); @@ -74,6 +84,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException @Override public String toString() { - return getName() + "(name=" + getName() + ", writeFlatVectorFormat=" + writeFormat + ")"; + return getName() + "(name=" + getName() + ", format=" + writeFormat + ")"; } } diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 6c21437d71d28..bf96d9c2de886 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -9,5 +9,6 @@ org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat +org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat 
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 1fd99df020ebf..108739533a76b 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -62,12 +62,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Locale; -import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; -import static org.hamcrest.Matchers.either; -import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.oneOf; public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -196,11 +193,12 @@ public KnnVectorsFormat knnVectorsFormat() { } }; String expectedPattern = "ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," - + " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat," - + " flatVectorScorer=%s())"; - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer))); + + " rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}()))," + + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}()))"; + var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer"); + var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); 
+ assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer)); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index c5d85b5cc4681..0500423b0e9fc 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -50,15 +50,13 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Locale; import static java.lang.String.format; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.hamcrest.Matchers.either; -import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.oneOf; public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -91,14 +89,16 @@ public KnnVectorsFormat knnVectorsFormat() { return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, false, 1, null); } }; - String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat," + + " maxConn=10, beamWidth=20," + " flatVectorFormat=ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," - + " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat," - + " flatVectorScorer=%s())"; + + " 
rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}()))," + + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}())))"; - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer))); + var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer"); + var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer)); } public void testSingleVectorCase() throws Exception { From 5ce914db9e04f13e9ed2ce36897ea92f35857139 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 17:10:56 +0100 Subject: [PATCH 10/46] More propagation --- .../diskbbq/ES920DiskBBQVectorsFormat.java | 17 +++- .../es93/ES93GenericFlatVectorsFormat.java | 1 + .../vectors/DenseVectorFieldMapper.java | 9 +- ...S920DiskBBQBFloat16VectorsFormatTests.java | 96 +++++++++++++++++++ .../ES920DiskBBQVectorsFormatTests.java | 6 ++ 5 files changed, 120 insertions(+), 9 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index 99bc9a9d7bdb2..086ab06adff0f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -18,6 +18,7 @@ import 
org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat; import java.io.IOException; import java.util.Map; @@ -61,9 +62,14 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); private static final Map supportedFormats = Map.of( float32VectorFormat.getName(), - float32VectorFormat + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat ); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. 
@@ -78,14 +84,15 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final int vectorPerCluster; private final int centroidsPerParentCluster; - private final boolean useDirectIO; private final DirectIOCapableFlatVectorsFormat rawVectorFormat; + private final boolean useDirectIO; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false); + this(vectorPerCluster, centroidsPerParentCluster, false, false); } - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) { + // TODO: ElementType + public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useBFloat16, boolean useDirectIO) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -109,8 +116,8 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu } this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; + this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; this.useDirectIO = useDirectIO; - this.rawVectorFormat = float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index e2026e24506e7..d2b71f1c63353 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -57,6 +57,7 @@ public ES93GenericFlatVectorsFormat() { this(false, false); } + // TODO: ElementType public ES93GenericFlatVectorsFormat(boolean useBFloat16, boolean useDirectIO) { super(NAME); writeFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 1a506a5e94cdf..f00dcb15dbe69 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -130,6 +130,7 @@ public class DenseVectorFieldMapper extends FieldMapper { private static final boolean DEFAULT_HNSW_EARLY_TERMINATION = false; public static boolean isNotUnitVector(float magnitude) { + // TODO: need different EPS for bfloat16? 
return Math.abs(magnitude - 1.0f) > EPS; } @@ -2213,14 +2214,14 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { case FLOAT -> new ES920DiskBBQVectorsFormat( clusterSize, ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - onDiskRescore, - false + false, + onDiskRescore ); case BFLOAT16 -> new ES920DiskBBQVectorsFormat( clusterSize, ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - onDiskRescore, - true + true, + onDiskRescore ); default -> throw new AssertionError(); }; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..38548deff5b45 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.diskbbq; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES920DiskBBQBFloat16VectorsFormatTests extends ES920DiskBBQVectorsFormatTests { + @Override + boolean useBFloat16() { + return true; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testWriterRamEstimate() throws Exception { + // estimate is different due to bfloat16 + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); + + private static void assertFloatsWithinBounds(AssertionError error) { + 
Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); + if (m.matches() == false) { + throw error; // nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = expected * 0.01; // within 1% + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 29e6c59d995be..5d889ed9e6061 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -63,6 +63,10 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase private KnnVectorsFormat format; + boolean useBFloat16() { + return false; + } + @Before @Override public void setUp() throws Exception { @@ -70,6 +74,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), + useBFloat16(), random().nextBoolean() ); } else { @@ -77,6 +82,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), + useBFloat16(), random().nextBoolean() ); } From 2e815fe9ae74d15c5186ee9885d0e7f5bbeccd0b Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 10 Oct 2025 17:19:19 +0100 Subject: [PATCH 11/46] Add a basic generic HNSW implementation --- .../vectors/es93/ES93HnswVectorsFormat.java | 
53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java new file mode 100644 index 0000000000000..f1fa528b391af --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; + +import java.io.IOException; + +public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat { + + static final String NAME = "ES93HnswVectorsFormat"; + + private final FlatVectorsFormat flatVectorsFormat; + + public ES93HnswVectorsFormat() { + super(NAME); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(); + } + + public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) { + super(NAME, maxConn, beamWidth); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); + } + + @Override + protected FlatVectorsFormat flatVectorsFormat() { + return flatVectorsFormat; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } +} From 9c81e3341a6ab6160d2a67a5d1eb4531e95b6477 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 13 Oct 2025 11:03:24 +0100 Subject: [PATCH 12/46] Relax unit vector check for bfloat16 --- .../org/elasticsearch/index/codec/vectors/BQVectorUtils.java | 3 ++- .../index/codec/vectors/OptimizedScalarQuantizer.java | 2 +- .../codec/vectors/es818/ES818BinaryFlatVectorsScorer.java | 3 +++ 3 
files changed, 6 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java index cba55f8a7e942..c200828876e85 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java @@ -26,7 +26,8 @@ /** Utility class for vector quantization calculations */ public class BQVectorUtils { - private static final float EPSILON = 1e-4f; + // NOTE: this is currently > 1e-4f due to bfloat16 + private static final float EPSILON = 1e-2f; public static double sqrtNewtonRaphson(double x, double curr, double prev) { return (curr == prev) ? curr : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java index 293cb61e9105c..eac3e708dfe66 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java @@ -109,7 +109,7 @@ public QuantizationResult[] multiScalarQuantize( } public QuantizationResult scalarQuantize(float[] vector, float[] residualDestination, int[] destination, byte bits, float[] centroid) { - assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(vector); assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); assert vector.length <= destination.length; assert bits > 0 && bits <= 8; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index 
a4fcf71cf94be..fca247f084adf 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -27,11 +27,13 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; @@ -69,6 +71,7 @@ public RandomVectorScorer getRandomVectorScorer( assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer"; OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); float[] centroid = binarizedVectors.getCentroid(); + assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(target); float[] scratch = new float[vectorValues.dimension()]; int[] initial = new int[target.length]; byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; From ce690ffedc10b8c1fc5f3a4df312a5fa9676b61a Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 13 Oct 2025 11:51:28 +0100 Subject: [PATCH 13/46] Add tests for ES93HnswVectorsFormat --- server/src/main/java/module-info.java | 1 + .../vectors/es93/ES93HnswVectorsFormat.java | 6 + .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + .../ES93HnswBFloat16VectorsFormatTests.java | 99 ++++++++++++++ .../es93/ES93HnswVectorsFormatTests.java | 121 ++++++++++++++++++ 5 files changed, 228 insertions(+) create mode 100644 
server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 0ac1ec1fbb612..b3109d08b5deb 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -466,6 +466,7 @@ org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index f1fa528b391af..1e3fb9893e291 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -19,6 +19,7 @@ import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; import java.io.IOException; +import java.util.concurrent.ExecutorService; public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat { @@ -36,6 +37,11 @@ public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boole flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); } + public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + flatVectorsFormat = new 
ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); + } + @Override protected FlatVectorsFormat flatVectorsFormat() { return flatVectorsFormat; diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index bf96d9c2de886..cd2fa7c3fdf9b 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -11,4 +11,5 @@ org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..f6b3a2aedca57 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.index.VectorEncoding; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES93HnswBFloat16VectorsFormatTests extends ES93HnswVectorsFormatTests { + + @Override + protected boolean useBFloat16() { + return true; + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.eE+-]+)> but was:<([0-9.eE+-]+)>"); + + private static void
assertFloatsWithinBounds(AssertionError error) { + Matcher m = FLOAT_ASSERTION_FAILURE.matcher(String.valueOf(error.getMessage())); + if (m.matches() == false) { + throw error; // no message, or not a float comparison failure: nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity; the tolerance is taken from the magnitude so negative values work + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = Math.abs(expected) * 0.01; // within 1% of |expected| + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java new file mode 100644 index 0000000000000..c6b7e4df978cc --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.IOException; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES93HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + private KnnVectorsFormat format; + + protected boolean useBFloat16() { + return false; + } + + @Override + public void setUp() throws Exception { + format = new ES93HnswVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); + 
super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES93HnswVectorsFormat(10, 20, false, false); + } + }; + String expectedPattern = + "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(-1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(0, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 0, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, -1, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(512 + 1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, false, false)); + expectThrows( + IllegalArgumentException.class, + () -> new ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService())); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + 
doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + assertEquals(vector.length * bytes, (long) offHeap.get("vec")); + assertEquals(1L, (long) offHeap.get("vex")); + assertEquals(2, offHeap.size()); + } + } + } + } +} From 1de5bd84b5907484b4753a4a6dc4ae2f41f4740d Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 13 Oct 2025 11:52:59 +0100 Subject: [PATCH 14/46] Remove unquantized format --- .../vectors/es93/ES93HnswVectorsFormat.java | 53 ------------------- 1 file changed, 53 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java deleted file mode 100644 index f1fa528b391af..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.index.codec.vectors.es93; - -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SegmentWriteState; -import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; - -import java.io.IOException; - -public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat { - - static final String NAME = "ES93HnswVectorsFormat"; - - private final FlatVectorsFormat flatVectorsFormat; - - public ES93HnswVectorsFormat() { - super(NAME); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(); - } - - public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) { - super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); - } - - @Override - protected FlatVectorsFormat flatVectorsFormat() { - return flatVectorsFormat; - } - - @Override - public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); - } - - @Override - public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); - } -} From e54d4c6ebec38579e4a559a70d4094347c2d3f18 
Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 13 Oct 2025 11:56:16 +0100 Subject: [PATCH 15/46] Consolidate parameter order --- server/src/main/java/module-info.java | 1 - .../vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java | 6 +++--- .../services/org.apache.lucene.codecs.KnnVectorsFormat | 1 - .../es93/ES93HnswBinaryQuantizedVectorsFormatTests.java | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 0ac1ec1fbb612..2987b3849e663 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -464,7 +464,6 @@ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat, - org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index c9cbe015c063e..c42701f1e5d6f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -47,11 +47,11 @@ public ES93HnswBinaryQuantizedVectorsFormat() { /** * Constructs a format using the given graph construction parameters. * - * @param maxConn the maximum number of connections to a node in the HNSW graph - * @param beamWidth the size of the queue maintained during graph construction. 
+ * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. * @param useDirectIO whether to use direct IO when reading raw vectors */ - public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO, boolean useBFloat16) { + public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) { super(NAME, maxConn, beamWidth); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); } diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index bf96d9c2de886..6c21437d71d28 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -9,6 +9,5 @@ org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat -org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 0500423b0e9fc..45e489662f3bf 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -73,7 +73,7 @@ boolean useBFloat16() { @Override public void setUp() throws Exception { - format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, random().nextBoolean(), useBFloat16()); + format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); super.setUp(); } From 80d980943984c16f8ad6f0b12ee894a4d085da6c Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 13 Oct 2025 16:48:48 +0100 Subject: [PATCH 16/46] Use the correct values --- .../es93/ES93BFloat16FlatVectorsWriter.java | 3 +- .../vectors/es93/ES93HnswVectorsFormat.java | 9 +++++- .../es93/ES93HnswVectorsFormatTests.java | 30 ++++++++----------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java index 3c143d94fd6b5..86894377bacde 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java @@ -23,7 +23,6 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -250,7 +249,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldI final IndexInput finalVectorDataInput = vectorDataInput; final RandomVectorScorerSupplier randomVectorScorerSupplier = 
vectorsScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index 1e3fb9893e291..ad151147d87ea 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -37,7 +37,14 @@ public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boole flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); } - public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec) { + public ES93HnswVectorsFormat( + int maxConn, + int beamWidth, + boolean bfloat16, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec + ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java index c6b7e4df978cc..bd7e4f7f653bf 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -13,7 +13,6 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; -import 
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; @@ -64,20 +63,17 @@ protected Codec getCodec() { } public void testToString() { - FilterCodec customCodec = - new FilterCodec("foo", Codec.getDefault()) { - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswVectorsFormat(10, 20, false, false); - } - }; - String expectedPattern = - "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20," + - " flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + - " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())))"; + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES93HnswVectorsFormat(10, 20, false, false); + } + }; + String expectedPattern = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())))"; var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = - format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } @@ -90,13 +86,13 @@ public void testLimits() { expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, false, false)); expectThrows( IllegalArgumentException.class, - () -> new ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService())); + () -> new 
ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) + ); } public void testSimpleOffHeapSize() throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); - try (Directory dir = newDirectory(); - IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); w.addDocument(doc); From 74bb1303fe29c005769d69dc1d9624b7e9935cc3 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 14 Oct 2025 13:14:12 +0100 Subject: [PATCH 17/46] Remove unneeded method --- .../es93/DirectIOCapableLucene99FlatVectorsFormat.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index 87c8909aa8c9b..b67f7186b8f4d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -64,11 +64,6 @@ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio return new Lucene99FlatVectorsWriter(state, vectorsScorer); } - @Override - public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return fieldsReader(state, false); - } - @Override public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException { if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) { From 38183b76fa8baa81c50ded25ca19443487fd24a5 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 14 Oct 2025 15:19:15 +0100 Subject: [PATCH 18/46] Use bfloat16/directio in 
rescore tests --- .../search/vectors/RescoreKnnVectorQueryTests.java | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index ebe62fa0cba98..7500b902f115c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -283,10 +283,20 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th IndexWriterConfig iwc = new IndexWriterConfig(); // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore KnnVectorsFormat format = randomFrom( - new ES920DiskBBQVectorsFormat(DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, randomBoolean()), + new ES920DiskBBQVectorsFormat( + DEFAULT_VECTORS_PER_CLUSTER, + DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, + randomBoolean(), + randomBoolean() + ), new ES818BinaryQuantizedVectorsFormat(), new ES818HnswBinaryQuantizedVectorsFormat(), - new ES93HnswBinaryQuantizedVectorsFormat(), + new ES93HnswBinaryQuantizedVectorsFormat( + DEFAULT_VECTORS_PER_CLUSTER, + DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, + randomBoolean(), + randomBoolean() + ), new ES813Int8FlatVectorFormat(), new ES813Int8FlatVectorFormat(), new ES814HnswScalarQuantizedVectorsFormat() From 5277641fae4686a997cbf12d9e911cd7a00b5809 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 14 Oct 2025 17:24:37 +0100 Subject: [PATCH 19/46] Test updates --- .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 2 +- .../41_knn_search_bbq_hnsw_bfloat16.yml | 2 +- .../42_knn_search_bbq_flat_bfloat16.yml | 2 +- .../elasticsearch/index/mapper/MapperFeatures.java | 4 ++-- .../mapper/vectors/DenseVectorFieldMapperTests.java | 3 ++- .../upgrades/SemanticTextUpgradeIT.java | 13 
+++++++++++++ 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index b58ae2a29839a..db9a76f54ab8b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -340,7 +340,7 @@ setup: --- "Test index configured rescore vector with on-disk rescoring": - requires: - cluster_features: ["mapper.vectors.bbq_hnsw_on_disk_rescoring"] + cluster_features: ["mapper.vectors.hnsw_on_disk_rescoring"] reason: Needs on_disk_rescoring feature - skip: features: "headers" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index 980d9ba924fea..d348d40dd493b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -1,6 +1,6 @@ setup: - requires: - cluster_features: "mapper.vectors.bbq_hnsw_on_disk_rescoring" + cluster_features: "mapper.vectors.hnsw_on_disk_rescoring" reason: 'bfloat16 needs to be supported' - do: indices.create: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index 2b801e92d7b7c..79258864f09a4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -1,6 +1,6 @@ setup: - requires: - cluster_features: "mapper.vectors.bbq_hnsw_on_disk_rescoring" + cluster_features: "mapper.vectors.hnsw_on_disk_rescoring" reason: 'bfloat16 needs to be supported' - do: indices.create: diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 583c0734acc94..0beb4a8661dd8 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -57,7 +57,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature INDEX_MAPPING_IGNORE_DYNAMIC_BEYOND_FIELD_NAME_LIMIT = new NodeFeature( "mapper.ignore_dynamic_field_names_beyond_limit" ); - static final NodeFeature BBQ_HNSW_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.bbq_hnsw_on_disk_rescoring"); + public static final NodeFeature HNSW_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.hnsw_on_disk_rescoring"); @Override public Set getTestFeatures() { @@ -99,7 +99,7 @@ public Set getTestFeatures() { DISKBBQ_ON_DISK_RESCORING, PROVIDE_INDEX_SORT_SETTING_DEFAULTS, INDEX_MAPPING_IGNORE_DYNAMIC_BEYOND_FIELD_NAME_LIMIT, - BBQ_HNSW_ON_DISK_RESCORING + HNSW_ON_DISK_RESCORING ); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 436136fe526da..0d2016dea1877 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -2904,7 +2904,8 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { + efConstruction + ", 
flatVectorFormat=ES93BinaryQuantizedVectorsFormat(" + "name=ES93BinaryQuantizedVectorsFormat, " - + "writeFlatVectorFormat=Lucene99FlatVectorsFormat"; + + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat"; assertThat(knnVectorsFormat, hasToString(startsWith(expectedString))); } diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index 9dcf5537ee9d5..cc18b9e4f23a8 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.MapperFeatures; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.index.query.NestedQueryBuilder; @@ -34,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; +import org.junit.Before; import org.junit.BeforeClass; import java.io.IOException; @@ -83,6 +85,17 @@ public static Iterable parameters() { return List.of(new Object[] { true }, new Object[] { false }); } + private boolean runTest; + + @Before + public void checkSupport() { + if (CLUSTER_TYPE == ClusterType.OLD) { + runTest = DENSE_MODEL.getServiceSettings().elementType() != DenseVectorFieldMapper.ElementType.BFLOAT16 + || clusterHasFeature(MapperFeatures.HNSW_ON_DISK_RESCORING); + } + assumeTrue("Old cluster needs to support bfloat16", 
runTest); + } + public void testSemanticTextOperations() throws Exception { switch (CLUSTER_TYPE) { case OLD -> createAndPopulateIndex(); From 201ed46f44bca75ef88ddb83de667acefe614927 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 15 Oct 2025 10:48:45 +0100 Subject: [PATCH 20/46] Get parameters the right way round --- .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 2 +- .../test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml | 2 +- .../test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml | 2 +- .../es93/DirectIOCapableLucene99FlatVectorsFormat.java | 5 +++++ .../java/org/elasticsearch/index/mapper/MapperFeatures.java | 4 ++-- .../index/mapper/vectors/DenseVectorFieldMapper.java | 4 ++-- .../org/elasticsearch/upgrades/SemanticTextUpgradeIT.java | 2 +- 7 files changed, 13 insertions(+), 8 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index db9a76f54ab8b..0a2c9eb31dc72 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -340,7 +340,7 @@ setup: --- "Test index configured rescore vector with on-disk rescoring": - requires: - cluster_features: ["mapper.vectors.hnsw_on_disk_rescoring"] + cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] reason: Needs on_disk_rescoring feature - skip: features: "headers" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index d348d40dd493b..d873158e637c3 100644 --- 
a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -1,6 +1,6 @@ setup: - requires: - cluster_features: "mapper.vectors.hnsw_on_disk_rescoring" + cluster_features: "mapper.vectors.hnsw_bfloat16_on_disk_rescoring" reason: 'bfloat16 needs to be supported' - do: indices.create: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index 79258864f09a4..80fee2c53468f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -1,6 +1,6 @@ setup: - requires: - cluster_features: "mapper.vectors.hnsw_on_disk_rescoring" + cluster_features: "mapper.vectors.hnsw_bfloat16_on_disk_rescoring" reason: 'bfloat16 needs to be supported' - do: indices.create: diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index b67f7186b8f4d..fed2bddf6ede6 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -172,6 +172,11 @@ public int size() { return inner.size(); } + @Override + public DocIndexIterator iterator() { + return inner.iterator(); + } + @Override public RescorerOffHeapVectorValues copy() throws IOException { return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, 
scorer); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 0beb4a8661dd8..92900b0831a6e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -57,7 +57,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature INDEX_MAPPING_IGNORE_DYNAMIC_BEYOND_FIELD_NAME_LIMIT = new NodeFeature( "mapper.ignore_dynamic_field_names_beyond_limit" ); - public static final NodeFeature HNSW_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.hnsw_on_disk_rescoring"); + public static final NodeFeature HNSW_BFLOAT16_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.hnsw_bfloat16_on_disk_rescoring"); @Override public Set getTestFeatures() { @@ -99,7 +99,7 @@ public Set getTestFeatures() { DISKBBQ_ON_DISK_RESCORING, PROVIDE_INDEX_SORT_SETTING_DEFAULTS, INDEX_MAPPING_IGNORE_DYNAMIC_BEYOND_FIELD_NAME_LIMIT, - HNSW_ON_DISK_RESCORING + HNSW_BFLOAT16_ON_DISK_RESCORING ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index f00dcb15dbe69..912782d3ed21b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2075,8 +2075,8 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { return switch (elementType) { - case FLOAT -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, onDiskRescore, false); - case BFLOAT16 -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, onDiskRescore, true); + case FLOAT -> new 
ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, false, onDiskRescore); + case BFLOAT16 -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, true, onDiskRescore); default -> throw new AssertionError(); }; } diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index cc18b9e4f23a8..22236041a305a 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -91,7 +91,7 @@ public static Iterable parameters() { public void checkSupport() { if (CLUSTER_TYPE == ClusterType.OLD) { runTest = DENSE_MODEL.getServiceSettings().elementType() != DenseVectorFieldMapper.ElementType.BFLOAT16 - || clusterHasFeature(MapperFeatures.HNSW_ON_DISK_RESCORING); + || clusterHasFeature(MapperFeatures.HNSW_BFLOAT16_ON_DISK_RESCORING); } assumeTrue("Old cluster needs to support bfloat16", runTest); } From 5e6ddc984cce0b05bf0d9b5fd9eaf46fafd47ae3 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 15 Oct 2025 11:34:25 +0100 Subject: [PATCH 21/46] Use ElementType --- .../diskbbq/ES920DiskBBQVectorsFormat.java | 16 +++++++-- .../ES93BinaryQuantizedVectorsFormat.java | 7 ++-- .../es93/ES93GenericFlatVectorsFormat.java | 12 ++++--- .../ES93HnswBinaryQuantizedVectorsFormat.java | 14 +++++--- .../vectors/es93/ES93HnswVectorsFormat.java | 9 ++--- .../vectors/DenseVectorFieldMapper.java | 33 +++++-------------- ...S920DiskBBQBFloat16VectorsFormatTests.java | 6 ++-- .../ES920DiskBBQVectorsFormatTests.java | 9 ++--- ...ryQuantizedBFloat16VectorsFormatTests.java | 5 +-- ...ES93BinaryQuantizedVectorsFormatTests.java | 13 +++++--- .../ES93HnswBFloat16VectorsFormatTests.java | 5 +-- ...ryQuantizedBFloat16VectorsFormatTests.java | 5 +-- ...HnswBinaryQuantizedVectorsFormatTests.java | 29 
+++++++++------- .../es93/ES93HnswVectorsFormatTests.java | 29 +++++++++------- .../vectors/RescoreKnnVectorQueryTests.java | 5 +-- 15 files changed, 112 insertions(+), 85 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index 086ab06adff0f..20c74653b61dc 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -19,6 +19,7 @@ import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Map; @@ -88,11 +89,16 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final boolean useDirectIO; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false, false); + this(vectorPerCluster, centroidsPerParentCluster, DenseVectorFieldMapper.ElementType.FLOAT, false); } // TODO: ElementType - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useBFloat16, boolean useDirectIO) { + public ES920DiskBBQVectorsFormat( + int vectorPerCluster, + int centroidsPerParentCluster, + DenseVectorFieldMapper.ElementType elementType, + boolean useDirectIO + ) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -116,7 +122,11 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu } 
this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; - this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat; + this.rawVectorFormat = switch (elementType) { + case FLOAT -> float32VectorFormat; + case BFLOAT16 -> bfloat16VectorFormat; + default -> throw new IllegalArgumentException("Unsupported element type " + elementType); + }; this.useDirectIO = useDirectIO; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 2535784bd1004..81b1cab4af1cd 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -30,6 +30,7 @@ import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; @@ -97,12 +98,12 @@ public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat private final ES93GenericFlatVectorsFormat rawFormat; public ES93BinaryQuantizedVectorsFormat() { - this(false, false); + this(DenseVectorFieldMapper.ElementType.FLOAT, false); } - public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + public ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); - rawFormat = new ES93GenericFlatVectorsFormat(useBFloat16, useDirectIO); + rawFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); } @Override diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index d2b71f1c63353..a2aabbf8ec506 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -17,6 +17,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Map; @@ -54,13 +55,16 @@ public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { private final boolean useDirectIO; public ES93GenericFlatVectorsFormat() { - this(false, false); + this(DenseVectorFieldMapper.ElementType.FLOAT, false); } - // TODO: ElementType - public ES93GenericFlatVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + public ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); - writeFormat = useBFloat16 ? 
bfloat16VectorFormat : float32VectorFormat; + writeFormat = switch (elementType) { + case FLOAT -> float32VectorFormat; + case BFLOAT16 -> bfloat16VectorFormat; + default -> throw new IllegalArgumentException("Unsupported element type " + elementType); + }; this.useDirectIO = useDirectIO; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index c42701f1e5d6f..24f8015d57d21 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -51,9 +52,14 @@ public ES93HnswBinaryQuantizedVectorsFormat() { * @param beamWidth the size of the queue maintained during graph construction. 
* @param useDirectIO whether to use direct IO when reading raw vectors */ - public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) { + public ES93HnswBinaryQuantizedVectorsFormat( + int maxConn, + int beamWidth, + DenseVectorFieldMapper.ElementType elementType, + boolean useDirectIO + ) { super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } /** @@ -70,13 +76,13 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, + DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO, - boolean useBFloat16, int numMergeWorkers, ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index ad151147d87ea..61ba1f6d3e9fc 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -17,6 +17,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -32,21 +33,21 @@ public ES93HnswVectorsFormat() { flatVectorsFormat = new ES93GenericFlatVectorsFormat(); } 
- public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) { + public ES93HnswVectorsFormat(int maxConn, int beamWidth, DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); } public ES93HnswVectorsFormat( int maxConn, int beamWidth, - boolean bfloat16, + DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 912782d3ed21b..a002c68e1a956 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2074,11 +2074,7 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return switch (elementType) { - case FLOAT -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, false, onDiskRescore); - case BFLOAT16 -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, true, onDiskRescore); - default -> throw new AssertionError(); - }; + return new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, elementType, onDiskRescore); } @Override @@ -2143,11 +2139,7 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType 
elementType) { - return switch (elementType) { - case FLOAT -> new ES93BinaryQuantizedVectorsFormat(false, false); - case BFLOAT16 -> new ES93BinaryQuantizedVectorsFormat(false, true); - default -> throw new AssertionError(); - }; + return new ES93BinaryQuantizedVectorsFormat(elementType, false); } @Override @@ -2210,21 +2202,12 @@ static class BBQIVFIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return switch (elementType) { - case FLOAT -> new ES920DiskBBQVectorsFormat( - clusterSize, - ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - false, - onDiskRescore - ); - case BFLOAT16 -> new ES920DiskBBQVectorsFormat( - clusterSize, - ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - true, - onDiskRescore - ); - default -> throw new AssertionError(); - }; + return new ES920DiskBBQVectorsFormat( + clusterSize, + ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, + elementType, + onDiskRescore + ); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java index 38548deff5b45..f0e5f8369d588 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java @@ -9,6 +9,8 @@ package org.elasticsearch.index.codec.vectors.diskbbq; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -16,8 +18,8 @@ public class ES920DiskBBQBFloat16VectorsFormatTests extends ES920DiskBBQVectorsFormatTests { @Override - boolean useBFloat16() { - return true; + DenseVectorFieldMapper.ElementType elementType() { + return 
DenseVectorFieldMapper.ElementType.BFLOAT16; } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 5d889ed9e6061..be512ba2f5872 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -36,6 +36,7 @@ import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.Before; import java.io.IOException; @@ -63,8 +64,8 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase private KnnVectorsFormat format; - boolean useBFloat16() { - return false; + DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; } @Before @@ -74,7 +75,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), - useBFloat16(), + elementType(), random().nextBoolean() ); } else { @@ -82,7 +83,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), - useBFloat16(), + elementType(), random().nextBoolean() ); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java index 9ae394733631e..3c2370989a9e4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.index.VectorEncoding; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -18,8 +19,8 @@ public class ES93BinaryQuantizedBFloat16VectorsFormatTests extends ES93BinaryQuantizedVectorsFormatTests { @Override - boolean useBFloat16() { - return true; + DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.BFLOAT16; } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index 108739533a76b..bb29c602c01a4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -57,6 +57,7 @@ import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.ArrayList; @@ -75,13 +76,13 @@ public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT private KnnVectorsFormat format; - boolean useBFloat16() { - return false; + DenseVectorFieldMapper.ElementType elementType() { + return 
DenseVectorFieldMapper.ElementType.FLOAT; } @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(useBFloat16(), random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(elementType(), random().nextBoolean()); super.setUp(); } @@ -242,7 +243,11 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b assertEquals(expectVecOffHeap ? 2 : 1, offHeap.size()); assertTrue(offHeap.get("veb") > 0L); if (expectVecOffHeap) { - int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + int bytes = switch (elementType()) { + case FLOAT -> Float.BYTES; + case BFLOAT16 -> BFloat16.BYTES; + default -> throw new AssertionError(); + }; assertEquals(vector.length * bytes, (long) offHeap.get("vec")); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java index f6b3a2aedca57..d72ad12c9b9de 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.index.VectorEncoding; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -19,8 +20,8 @@ public class ES93HnswBFloat16VectorsFormatTests extends ES93HnswVectorsFormatTests { @Override - protected boolean useBFloat16() { - return true; + DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.BFLOAT16; } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java index c6f3e555013b3..493da2585da7e 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.codec.vectors.es93; import org.apache.lucene.index.VectorEncoding; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -19,8 +20,8 @@ public class ES93HnswBinaryQuantizedBFloat16VectorsFormatTests extends ES93HnswBinaryQuantizedVectorsFormatTests { @Override - boolean useBFloat16() { - return true; + DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.BFLOAT16; } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 45e489662f3bf..51128bc7a0b8e 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -47,6 +47,7 @@ import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Arrays; @@ -67,13 +68,13 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFor private KnnVectorsFormat format; - boolean useBFloat16() { - return false; + DenseVectorFieldMapper.ElementType 
elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; } @Override public void setUp() throws Exception { - format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); + format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, elementType(), random().nextBoolean()); super.setUp(); } @@ -86,7 +87,7 @@ public void testToString() { FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, false, 1, null); + return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, elementType(), false, 1, null); } }; String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat," @@ -142,15 +143,15 @@ public void testSingleVectorCase() throws Exception { } public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new 
ES93HnswBinaryQuantizedVectorsFormat(20, 0, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, elementType(), false)); expectThrows( IllegalArgumentException.class, - () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) + () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, elementType(), false, 1, new SameThreadExecutorService()) ); } @@ -194,7 +195,11 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b assertEquals(1L, (long) offHeap.get("vex")); assertTrue(offHeap.get("veb") > 0L); if (expectVecOffHeap) { - int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + int bytes = switch (elementType()) { + case FLOAT -> Float.BYTES; + case BFLOAT16 -> BFloat16.BYTES; + default -> throw new AssertionError(); + }; assertEquals(vector.length * bytes, (long) offHeap.get("vec")); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java index bd7e4f7f653bf..6f43c05e84f9f 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -27,6 +27,7 @@ import org.apache.lucene.util.SameThreadExecutorService; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -47,13 +48,13 @@ 
public class ES93HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase { private KnnVectorsFormat format; - protected boolean useBFloat16() { - return false; + DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; } @Override public void setUp() throws Exception { - format = new ES93HnswVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); + format = new ES93HnswVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, elementType(), random().nextBoolean()); super.setUp(); } @@ -66,7 +67,7 @@ public void testToString() { FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswVectorsFormat(10, 20, false, false); + return new ES93HnswVectorsFormat(10, 20, elementType(), false); } }; String expectedPattern = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20," @@ -78,15 +79,15 @@ public KnnVectorsFormat knnVectorsFormat() { } public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(-1, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(0, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 0, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, -1, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(512 + 1, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(-1, 20, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(0, 20, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 0, 
elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, -1, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(512 + 1, 20, elementType(), false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, elementType(), false)); expectThrows( IllegalArgumentException.class, - () -> new ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) + () -> new ES93HnswVectorsFormat(20, 100, elementType(), false, 1, new SameThreadExecutorService()) ); } @@ -106,7 +107,11 @@ public void testSimpleOffHeapSize() throws IOException { } var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); - int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + int bytes = switch (elementType()) { + case FLOAT -> Float.BYTES; + case BFLOAT16 -> BFloat16.BYTES; + default -> throw new AssertionError(); + }; assertEquals(vector.length * bytes, (long) offHeap.get("vec")); assertEquals(1L, (long) offHeap.get("vex")); assertEquals(2, offHeap.size()); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 7500b902f115c..1c517fbda0408 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -47,6 +47,7 @@ import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.search.profile.query.QueryProfiler; import 
org.elasticsearch.test.ESTestCase; @@ -286,7 +287,7 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th new ES920DiskBBQVectorsFormat( DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - randomBoolean(), + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), new ES818BinaryQuantizedVectorsFormat(), @@ -294,7 +295,7 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th new ES93HnswBinaryQuantizedVectorsFormat( DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - randomBoolean(), + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), new ES813Int8FlatVectorFormat(), From 0b8fcf22cfb49b37fd1bf05d94221decbccd2aa2 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 15 Oct 2025 14:07:12 +0100 Subject: [PATCH 22/46] Add basic HNSW support --- .../vectors/DenseVectorFieldMapper.java | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index a002c68e1a956..d6816941e6ac7 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -56,6 +56,7 @@ import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import 
org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -1314,17 +1315,14 @@ public enum VectorIndexType { public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - if (mNode == null) { - mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; - } - if (efConstructionNode == null) { - efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; - } + Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); - int m = XContentMapValues.nodeIntegerValue(mNode); - int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); + int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new HnswIndexOptions(m, efConstruction); + + return new HnswIndexOptions(m, efConstruction, onDiskRescore); } @Override @@ -1482,16 +1480,10 @@ public boolean supportsDimension(int dims) { public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - if (mNode == null) { - mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; - } - if (efConstructionNode == null) { - efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; - } Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); - int m = XContentMapValues.nodeIntegerValue(mNode); - int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int m = XContentMapValues.nodeIntegerValue(mNode, 
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); + int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); RescoreVector rescoreVector = null; @@ -1588,7 +1580,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map new ES815HnswBitVectorsFormat(m, efConstruction); + case BYTE -> new Lucene99HnswVectorsFormat(m, efConstruction); + case FLOAT, BFLOAT16 -> new ES93HnswVectorsFormat(m, efConstruction, elementType, onDiskRescore); + }; } @Override From 2766e4b6c68903bdadc95ab798ed27a6d6b9abfb Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 23 Oct 2025 16:55:30 +0100 Subject: [PATCH 23/46] Use new formats --- .../diskbbq/ES920DiskBBQVectorsFormat.java | 1 - .../es93/ES93GenericFlatVectorsFormat.java | 1 - .../vectors/DenseVectorFieldMapper.java | 44 ++++++++++++++++--- .../vectors/DenseVectorFieldTypeTests.java | 6 +-- .../vectors/RescoreKnnVectorQueryTests.java | 3 +- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index 20c74653b61dc..25b711da8c18f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -92,7 +92,6 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu this(vectorPerCluster, centroidsPerParentCluster, DenseVectorFieldMapper.ElementType.FLOAT, false); } - // TODO: ElementType public ES920DiskBBQVectorsFormat( int vectorPerCluster, int centroidsPerParentCluster, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index a4645e332393a..191da9fd24778 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -17,7 +17,6 @@ import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Map; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index fabdff1bcbe14..bb8f8d22f6d75 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -55,6 +55,7 @@ import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; @@ -1992,9 +1993,24 @@ public static class HnswIndexOptions extends DenseVectorIndexOptions { @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { return switch (elementType) { - case BIT -> new ES815HnswBitVectorsFormat(m, efConstruction); - case BYTE -> new Lucene99HnswVectorsFormat(m, efConstruction); - 
case FLOAT, BFLOAT16 -> new ES93HnswVectorsFormat(m, efConstruction, elementType, onDiskRescore); + case BIT -> new ES93HnswVectorsFormat( + m, + efConstruction, + ES93GenericFlatVectorsFormat.ElementType.BIT, + onDiskRescore + ); + case BYTE, FLOAT -> new ES93HnswVectorsFormat( + m, + efConstruction, + ES93GenericFlatVectorsFormat.ElementType.STANDARD, + onDiskRescore + ); + case BFLOAT16 -> new ES93HnswVectorsFormat( + m, + efConstruction, + ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, + onDiskRescore + ); }; } @@ -2069,7 +2085,21 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, elementType, onDiskRescore); + return switch (elementType) { + case FLOAT -> new ES93HnswBinaryQuantizedVectorsFormat( + m, + efConstruction, + ES93GenericFlatVectorsFormat.ElementType.STANDARD, + onDiskRescore + ); + case BFLOAT16 -> new ES93HnswBinaryQuantizedVectorsFormat( + m, + efConstruction, + ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, + onDiskRescore + ); + case BYTE, BIT -> throw new AssertionError(); + }; } @Override @@ -2134,7 +2164,11 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return new ES93BinaryQuantizedVectorsFormat(elementType, false); + return switch (elementType) { + case FLOAT -> new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false); + case BFLOAT16 -> new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, false); + case BYTE, BIT -> throw new AssertionError(); + }; } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index fe3ed3afd788a..3686d5f802fac 100644 
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -65,7 +65,7 @@ private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { private static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsNonQuantized() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomBoolean()), new DenseVectorFieldMapper.FlatIndexOptions() ); } @@ -86,7 +86,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomFlatIndexOpti public static DenseVectorFieldMapper.DenseVectorIndexOptions randomGpuSupportedIndexOptions() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 3199)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 3199), randomBoolean()), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 3199), @@ -108,7 +108,7 @@ public static DenseVectorFieldMapper.VectorSimilarity randomGPUSupportedSimilari public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsAll() { List options = new ArrayList<>( Arrays.asList( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomBoolean()), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 1c517fbda0408..7d42570498f5f 100644 
--- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -45,6 +45,7 @@ import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -295,7 +296,7 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th new ES93HnswBinaryQuantizedVectorsFormat( DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, - randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), + randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), randomBoolean() ), new ES813Int8FlatVectorFormat(), From f352602f9769327cfee1d0e9387f2826d2196ced Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 23 Oct 2025 17:07:15 +0100 Subject: [PATCH 24/46] Add normal HNSW --- .../elasticsearch/index/store/DirectIOIT.java | 11 +++++------ .../mapper/vectors/DenseVectorFieldMapper.java | 16 +++++++--------- .../vectors/DenseVectorFieldMapperTests.java | 6 +++--- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index ae888f50155cd..b0bac4686eed0 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ 
b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -73,7 +73,7 @@ protected boolean useDirectIO(String name, IOContext context, OptionalLong fileL @ParametersFactory public static Iterable parameters() { - return List.of(new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" }); + return List.of(new Object[] { "hnsw" }, new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" }); } public DirectIOIT(String type) { @@ -113,15 +113,14 @@ private String indexVectors(boolean directIO) { indexDoc(indexName, Integer.toString(i), "fooVector", IntStream.range(0, 64).mapToDouble(d -> randomFloat()).toArray()); } refresh(); - assertBBQIndexType(indexName, type); // test assertion to ensure that the correct index type is being used + assertIndexType(indexName, type); // test assertion to ensure that the correct index type is being used return indexName; } - @SuppressWarnings("unchecked") - static void assertBBQIndexType(String indexName, String type) { + static void assertIndexType(String indexName, String type) { var response = indicesAdmin().prepareGetFieldMappings(indexName).setFields("fooVector").get(); - var map = (Map) response.fieldMappings(indexName, "fooVector").sourceAsMap().get("fooVector"); - assertThat((String) ((Map) map.get("index_options")).get("type"), is(equalTo(type))); + var map = (Map) response.fieldMappings(indexName, "fooVector").sourceAsMap().get("fooVector"); + assertThat(((Map) map.get("index_options")).get("type"), is(equalTo(type))); } @TestLogging(value = "org.elasticsearch.index.store.FsDirectoryFactory:DEBUG", reason = "to capture trace logging for direct IO") diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index bb8f8d22f6d75..0e893b6c540ab 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1993,12 +1993,7 @@ public static class HnswIndexOptions extends DenseVectorIndexOptions { @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { return switch (elementType) { - case BIT -> new ES93HnswVectorsFormat( - m, - efConstruction, - ES93GenericFlatVectorsFormat.ElementType.BIT, - onDiskRescore - ); + case BIT -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.BIT, onDiskRescore); case BYTE, FLOAT -> new ES93HnswVectorsFormat( m, efConstruction, @@ -2035,6 +2030,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); + if (onDiskRescore) { + builder.field("on_disk_rescore", true); + } builder.endObject(); return builder; } @@ -2044,12 +2042,12 @@ public boolean doEquals(DenseVectorIndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; HnswIndexOptions that = (HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction; + return m == that.m && efConstruction == that.efConstruction && onDiskRescore == that.onDiskRescore; } @Override public int doHashCode() { - return Objects.hash(m, efConstruction); + return Objects.hash(m, efConstruction, onDiskRescore); } @Override @@ -2067,7 +2065,7 @@ public int efConstruction() { @Override public String toString() { - return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; + return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + ", on_disk_rescore=" + onDiskRescore + "}"; } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 0ef2a06f33391..128f5c2b55033 100644 --- 
a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -2764,12 +2764,12 @@ public void testKnnVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=" + String expectedString = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=" + (setM ? m : DEFAULT_MAX_CONN) + ", beamWidth=" + (setEfConstruction ? efConstruction : DEFAULT_BEAM_WIDTH) - + ", flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())" - + ")"; + + ", flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat" + + ", format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=DefaultFlatVectorScorer())))"; assertEquals(expectedString, knnVectorsFormat.toString()); } From ca4551c0868fc03637c0d32ee139fe64dc876fd9 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 24 Oct 2025 11:31:47 +0100 Subject: [PATCH 25/46] Update for merge --- .../DenseVectorFromBinaryBlockLoader.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java index f5f9e8dc88295..4bb94104f4f63 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java @@ -13,11 +13,15 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; import 
org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; public class DenseVectorFromBinaryBlockLoader extends BlockDocValuesReader.DocValuesBlockLoader { private final String fieldName; @@ -50,6 +54,7 @@ public AllReader reader(LeafReaderContext context) throws IOException { } return switch (elementType) { case FLOAT -> new FloatDenseVectorFromBinary(docValues, dims, indexVersion); + case BFLOAT16 -> new BFloat16DenseVectorFromBinary(docValues, dims, indexVersion); case BYTE -> new ByteDenseVectorFromBinary(docValues, dims, indexVersion); case BIT -> new BitDenseVectorFromBinary(docValues, dims, indexVersion); }; @@ -132,6 +137,32 @@ public String toString() { } } + private static class BFloat16DenseVectorFromBinary extends AbstractDenseVectorFromBinary { + BFloat16DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion, new float[dims]); + } + + @Override + protected void writeScratchToBuilder(float[] scratch, BlockLoader.FloatBuilder builder) { + for (float value : scratch) { + builder.appendFloat(value); + } + } + + @Override + protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { + ShortBuffer sb = ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + BFloat16.bFloat16ToFloat(sb, scratch); + } + + @Override + public String toString() { + return "BFloat16DenseVectorFromBinary.Bytes"; + } + } + private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary { ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { 
this(docValues, dims, indexVersion, dims); From 8f92014c5c55c90946027a1edbfe9f3dbdadb857 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 24 Oct 2025 13:07:52 +0100 Subject: [PATCH 26/46] Update tests --- .../index/codec/vectors/BFloat16.java | 4 + .../vectors/DenseVectorFieldMapper.java | 47 ++++-- .../search/vectors/VectorData.java | 6 +- ...S920DiskBBQBFloat16VectorsFormatTests.java | 134 ++++++++++-------- .../ES920DiskBBQVectorsFormatTests.java | 28 +--- .../vectors/DenseVectorFieldMapperTests.java | 36 +++-- .../vectors/DenseVectorFieldTypeTests.java | 2 + .../vectors/ExactKnnQueryBuilderTests.java | 2 +- .../services/cohere/CohereService.java | 2 +- .../mapper/RankVectorsFieldMapper.java | 2 +- .../mapper/RankVectorsFieldMapperTests.java | 35 +++-- 11 files changed, 175 insertions(+), 123 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index 8d25ab54d8ca1..11eaf69344c97 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -29,6 +29,10 @@ public static short floatToBFloat16(float f) { return (short) (Float.floatToIntBits(f) >>> 16); } + public static float truncateToBFloat16(float f) { + return Float.intBitsToFloat(Float.floatToIntBits(f) & 0xffff0000); + } + public static float bFloat16ToFloat(short bf) { return Float.intBitsToFloat(bf << 16); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 97a483f27a6c3..6bac155fa547f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -127,16 +127,10 @@ */ public class 
DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; - private static final float EPS = 1e-3f; public static final int BBQ_MIN_DIMS = 64; private static final boolean DEFAULT_HNSW_EARLY_TERMINATION = false; - public static boolean isNotUnitVector(float magnitude) { - // TODO: need different EPS for bfloat16? - return Math.abs(magnitude - 1.0f) > EPS; - } - /** * The heuristic to utilize when executing a filtered search against vectors indexed in an HNSW graph. */ @@ -540,6 +534,8 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib public abstract void writeValue(ByteBuffer byteBuffer, float value); + public abstract void writeValues(ByteBuffer byteBuffer, float[] values); + public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException; abstract IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext); @@ -557,6 +553,10 @@ public abstract VectorData parseKnnVector( public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes); + public boolean isUnitVector(float squaredMagnitude) { + return Math.abs(squaredMagnitude - 1.0f) < 1e-3f; + } + public void checkVectorBounds(float[] vector) { StringBuilder errors = checkVectorErrors(vector); if (errors != null) { @@ -641,6 +641,13 @@ public void writeValue(ByteBuffer byteBuffer, float value) { byteBuffer.put((byte) value); } + @Override + public void writeValues(ByteBuffer byteBuffer, float[] values) { + for (float f : values) { + byteBuffer.put((byte) f); + } + } + @Override public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { b.value(byteBuffer.get()); @@ -893,6 +900,12 @@ public void writeValue(ByteBuffer byteBuffer, float value) { byteBuffer.putFloat(value); } + @Override + public void writeValues(ByteBuffer byteBuffer, float[] values) { + 
byteBuffer.asFloatBuffer().put(values); + byteBuffer.position(byteBuffer.position() + (values.length * Float.BYTES)); + } + @Override public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { b.value(byteBuffer.getFloat()); @@ -956,7 +969,7 @@ void checkVectorMagnitude(VectorSimilarity similarity, UnaryOperator but was:<([0-9.-]+)>"); - - private static void assertFloatsWithinBounds(AssertionError error) { - Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); - if (m.matches() == false) { - throw error; // nothing to do with us, just rethrow + protected void assertOffHeapByteSize(LeafReader r, String fieldName) throws IOException { + var fieldInfo = r.getFieldInfos().fieldInfo(fieldName); + + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader(fieldName); + } + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + long totalByteSize = offHeap.values().stream().mapToLong(Long::longValue).sum(); + // IVF doesn't report stats at the moment + assertThat(offHeap, anEmptyMap()); + assertThat(totalByteSize, equalTo(0L)); + } else { + throw new AssertionError("unexpected:" + r.getClass()); } - - // numbers just need to be in the same vicinity - double expected = Double.parseDouble(m.group(1)); - double actual = Double.parseDouble(m.group(2)); - double allowedError = expected * 0.01; // within 1% - assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index be512ba2f5872..9b452fe4fb9cb 100644 --- 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -11,7 +11,6 @@ import com.carrotsearch.randomizedtesting.generators.RandomPicks; import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -37,11 +36,11 @@ import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; import org.junit.Before; import java.io.IOException; import java.util.List; -import java.util.Locale; import java.util.concurrent.atomic.AtomicBoolean; import static java.lang.String.format; @@ -52,8 +51,7 @@ import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.oneOf; +import static org.hamcrest.Matchers.hasToString; public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -64,10 +62,6 @@ public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase private KnnVectorsFormat format; - DenseVectorFieldMapper.ElementType elementType() { - return DenseVectorFieldMapper.ElementType.FLOAT; - } - @Before @Override public void setUp() throws Exception { @@ -75,7 +69,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), - elementType(), + 
DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean() ); } else { @@ -83,7 +77,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), - elementType(), + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean() ); } @@ -109,7 +103,7 @@ protected VectorEncoding randomVectorEncoding() { @Override public void testSearchWithVisitedLimit() { - // ivf doesn't enforce visitation limit + throw new AssumptionViolatedException("ivf doesn't enforce visitation limit"); } @Override @@ -142,17 +136,9 @@ public void testAdvance() throws Exception { } public void testToString() { - FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new ES920DiskBBQVectorsFormat(128, 4); - } - }; - String expectedPattern = "ES920DiskBBQVectorsFormat(vectorPerCluster=128)"; + KnnVectorsFormat format = new ES920DiskBBQVectorsFormat(128, 4); - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + assertThat(format, hasToString("ES920DiskBBQVectorsFormat(vectorPerCluster=128)")); } public void testLimits() { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 128f5c2b55033..2e0aef6037636 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.index.codec.CodecService; import 
org.elasticsearch.index.codec.LegacyPerFieldMapperCodec; import org.elasticsearch.index.codec.PerFieldMapperCodec; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; @@ -84,11 +85,14 @@ public class DenseVectorFieldMapperTests extends SyntheticVectorsMapperTestCase private final int dims; public DenseVectorFieldMapperTests() { - this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); + this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); this.indexed = usually(); this.indexOptionsSet = this.indexed && randomBoolean(); int baseDims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4; - int randomMultiplier = ElementType.FLOAT == elementType ? randomIntBetween(1, 64) : 1; + int randomMultiplier = switch (elementType) { + case FLOAT, BFLOAT16 -> randomIntBetween(1, 64); + case BYTE, BIT -> 1; + }; this.dims = baseDims * randomMultiplier; } @@ -148,9 +152,12 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I @Override protected Object getSampleValueForDocument() { - return elementType == ElementType.FLOAT - ? convertToList(randomNormalizedVector(this.dims)) - : convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? 
this.dims / Byte.SIZE : dims)); + return switch (elementType) { + case FLOAT -> convertToList(randomNormalizedVector(this.dims)); + case BFLOAT16 -> convertToBFloat16List(randomNormalizedVector(this.dims)); + case BYTE -> convertToList(randomByteArrayOfLength(dims)); + case BIT -> convertToList(randomByteArrayOfLength(this.dims / Byte.SIZE)); + }; } public static List convertToList(float[] vector) { @@ -161,6 +168,14 @@ public static List convertToList(float[] vector) { return list; } + public static List convertToBFloat16List(float[] vector) { + List list = new ArrayList<>(vector.length); + for (float v : vector) { + list.add(BFloat16.truncateToBFloat16(v)); + } + return list; + } + public static List convertToList(byte[] vector) { List list = new ArrayList<>(vector.length); for (byte v : vector) { @@ -3037,24 +3052,23 @@ protected boolean supportsEmptyInputArray() { private static class DenseVectorSyntheticSourceSupport implements SyntheticSourceSupport { private final int dims = between(5, 1000); - private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); + private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); private final boolean indexed = randomBoolean(); private final boolean indexOptionsSet = indexed && randomBoolean(); @Override public SyntheticSourceExample example(int maxValues) throws IOException { Object value = switch (elementType) { - case BYTE, BIT: - yield randomList(dims, dims, ESTestCase::randomByte); - case FLOAT, BFLOAT16: - yield randomList(dims, dims, ESTestCase::randomFloat); + case BYTE, BIT -> randomList(dims, dims, ESTestCase::randomByte); + case FLOAT -> randomList(dims, dims, ESTestCase::randomFloat); + case BFLOAT16 -> randomList(dims, dims, () -> BFloat16.truncateToBFloat16(randomFloat())); }; return new SyntheticSourceExample(value, value, this::mapping); } private void mapping(XContentBuilder b) throws 
IOException { b.field("type", "dense_vector"); - if (elementType == ElementType.BYTE || elementType == ElementType.BIT || randomBoolean()) { + if (elementType != ElementType.FLOAT || randomBoolean()) { b.field("element_type", elementType.toString()); } b.field("dims", elementType == ElementType.BIT ? dims * Byte.SIZE : dims); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 8eb782a9bc848..fc5bb65789cc3 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER; import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BFLOAT16; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BIT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; @@ -788,6 +789,7 @@ public void testRescoreOversampleQueryOverrides() { public void testFilterSearchThreshold() { List>> cases = List.of( Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), + Tuple.tuple(BFLOAT16, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()), Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy()) ); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java 
b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java index 5871c1d8f18c5..27728aaa550b1 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java @@ -101,7 +101,7 @@ protected void doAssertLuceneQuery(ExactKnnQueryBuilder queryBuilder, Query quer float[] expected = Arrays.copyOf(queryBuilder.getQuery().asFloatVector(), queryBuilder.getQuery().asFloatVector().length); float magnitude = VectorUtil.dotProduct(expected, expected); if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE) - && DenseVectorFieldMapper.isNotUnitVector(magnitude)) { + && DenseVectorFieldMapper.FLOAT_ELEMENT.isUnitVector(magnitude) == false) { VectorUtil.l2normalize(expected); assertArrayEquals(expected, denseVectorQuery.getQuery(), 0.0f); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 5fcb23e1fa23b..a0cb0ea1016be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -333,7 +333,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { /** * Returns the default similarity measure for the embedding type. * Cohere embeddings are expected to be normalized to unit vectors, but due to floating point precision issues, - * our check ({@link DenseVectorFieldMapper#isNotUnitVector(float)}) often fails. + * our check ({@link DenseVectorFieldMapper.Element#isUnitVector(float)}) often fails. * Therefore, we use cosine similarity to ensure compatibility. * * @return The default similarity measure. 
diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java index adb925757b6ca..acf920eabb268 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java @@ -342,7 +342,7 @@ public void parse(DocumentParserContext context) throws IOException { ByteBuffer buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.LITTLE_ENDIAN); ByteBuffer magnitudeBuffer = ByteBuffer.allocate(vectors.size() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); for (VectorData vector : vectors) { - vector.addToBuffer(buffer); + vector.addToBuffer(element, buffer); magnitudeBuffer.putFloat((float) Math.sqrt(element.computeSquaredMagnitude(vector))); } String vectorFieldName = fieldType().name(); diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index 585f98983ea30..88517a860bb60 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.LuceneDocument; @@ -48,6 
+49,7 @@ import java.util.stream.Stream; import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomNormalizedVector; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToBFloat16List; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToList; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -61,9 +63,12 @@ public class RankVectorsFieldMapperTests extends SyntheticVectorsMapperTestCase private final int dims; public RankVectorsFieldMapperTests() { - this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); + this.elementType = ElementType.BFLOAT16;// randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); int baseDims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4; - int randomMultiplier = ElementType.FLOAT == elementType ? randomIntBetween(1, 64) : 1; + int randomMultiplier = switch (elementType) { + case FLOAT, BFLOAT16 -> randomIntBetween(1, 64); + case BYTE, BIT -> 1; + }; this.dims = baseDims * randomMultiplier; } @@ -92,11 +97,12 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I @Override protected Object getSampleValueForDocument() { int numVectors = randomIntBetween(1, 16); - return Stream.generate( - () -> elementType == ElementType.FLOAT - ? convertToList(randomNormalizedVector(this.dims)) - : convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? 
this.dims / Byte.SIZE : dims)) - ).limit(numVectors).toList(); + return Stream.generate(() -> switch (elementType) { + case FLOAT -> convertToList(randomNormalizedVector(this.dims)); + case BFLOAT16 -> convertToBFloat16List(randomNormalizedVector(this.dims)); + case BYTE -> convertToList(randomByteArrayOfLength(dims)); + case BIT -> convertToList(randomByteArrayOfLength(this.dims / Byte.SIZE)); + }).limit(numVectors).toList(); } @Override @@ -467,22 +473,25 @@ protected boolean supportsEmptyInputArray() { private static class DenseVectorSyntheticSourceSupport implements SyntheticSourceSupport { private final int dims = between(5, 1000); private final int numVecs = between(1, 16); - private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); + private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); @Override public SyntheticSourceExample example(int maxValues) { Object value = switch (elementType) { - case BYTE, BIT: - yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); - case FLOAT, BFLOAT16: - yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); + case BYTE, BIT -> randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); + case FLOAT -> randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); + case BFLOAT16 -> randomList( + numVecs, + numVecs, + () -> randomList(dims, dims, () -> BFloat16.truncateToBFloat16(randomFloat())) + ); }; return new SyntheticSourceExample(value, value, this::mapping); } private void mapping(XContentBuilder b) throws IOException { b.field("type", "rank_vectors"); - if (elementType == ElementType.BYTE || elementType == ElementType.BIT || randomBoolean()) { + if (elementType != ElementType.FLOAT || randomBoolean()) { b.field("element_type", elementType.toString()); } b.field("dims", elementType 
== ElementType.BIT ? dims * Byte.SIZE : dims); From 8c697c3707dada168e164f5f1fd99752ae83755d Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 24 Oct 2025 13:59:23 +0100 Subject: [PATCH 27/46] Leave semantic text and rank vectors alone for the moment --- .../TestDenseInferenceServiceExtension.java | 2 +- .../elastic/ElasticTextEmbeddingPayload.java | 3 +- ...cInferenceMetadataFieldsRecoveryTests.java | 3 +- .../mapper/SemanticTextFieldTests.java | 11 +- .../mapper/RankVectorsDVLeafFieldData.java | 47 +------ .../mapper/RankVectorsFieldMapper.java | 2 +- .../script/RankVectorsScoreScriptUtils.java | 3 +- .../mapper/RankVectorsFieldMapperTests.java | 33 ++--- .../RankVectorsScriptDocValuesTests.java | 122 +----------------- .../upgrades/SemanticTextUpgradeIT.java | 13 -- 10 files changed, 37 insertions(+), 202 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ff51f2b99670a..ebd881a9ed141 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -255,7 +255,7 @@ private static List generateEmbedding(String input, int dimensions, Dense // Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE, BFLOAT16 -> dimensions; + case FLOAT, BFLOAT16, BYTE -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index bf4db3f382912..78647304dcfa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -97,7 +97,8 @@ public DenseEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoin return switch (model.apiServiceSettings().elementType()) { case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); case BYTE -> TextEmbeddingBytes.PARSER.apply(p, null); - case FLOAT, BFLOAT16 -> TextEmbeddingFloat.PARSER.apply(p, null); + case FLOAT -> TextEmbeddingFloat.PARSER.apply(p, null); + case BFLOAT16 -> throw new UnsupportedOperationException("Bfloat16 not supported"); }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 9bc1736a85c7b..1c677fb9532cd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -269,8 +269,9 @@ private static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case 
FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs); + case BFLOAT16 -> throw new AssertionError(); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 11411fe1ce83a..14963f9f82f3b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -190,8 +190,9 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model mo return switch (model.getTaskType()) { case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); + case BFLOAT16 -> throw new AssertionError(); }; default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; @@ -222,7 +223,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); - assert elementType == DenseVectorFieldMapper.ElementType.FLOAT || elementType 
== DenseVectorFieldMapper.ElementType.BFLOAT16; + assert elementType == DenseVectorFieldMapper.ElementType.FLOAT; List chunks = new ArrayList<>(); for (String input : inputs) { @@ -272,8 +273,9 @@ public static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); + case BFLOAT16 -> throw new AssertionError(); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); @@ -415,8 +417,9 @@ public static ChunkedInference toChunkedResult( ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText); double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType()); EmbeddingResults.Embedding embedding = switch (elementType) { - case FLOAT, BFLOAT16 -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values)); + case BFLOAT16 -> throw new AssertionError(); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); } diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java index 60d6bfc5586ee..f56b974e0a95a 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java +++ 
b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java @@ -16,7 +16,6 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; -import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; @@ -25,6 +24,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; final class RankVectorsDVLeafFieldData implements LeafFieldData { @@ -123,51 +123,12 @@ public Object nextValue() { VectorIterator iterator = new FloatRankVectorsDocValuesField.FloatVectorIterator(ref, vector, numVecs); while (iterator.hasNext()) { float[] v = iterator.next(); - vectors.add(v.clone()); - } - return vectors; - } - }; - case BFLOAT16 -> new FormattedDocValues() { - private final float[] vector = new float[dims]; - private BytesRef ref = null; - private int numVecs = -1; - private final BinaryDocValues binary; - { - try { - binary = DocValues.getBinary(reader, field); - } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values", e); - } - } - - @Override - public boolean advanceExact(int docId) throws IOException { - if (binary == null || binary.advanceExact(docId) == false) { - return false; - } - ref = binary.binaryValue(); - assert ref.length % (Short.BYTES * dims) == 0; - numVecs = ref.length / (Short.BYTES * dims); - return true; - } - - @Override - public int docValueCount() { - return 1; - } - - @Override - public Object nextValue() { - List vectors = new ArrayList<>(numVecs); - VectorIterator iterator = new 
BFloat16RankVectorsDocValuesField.BFloat16VectorIterator(ref, vector, numVecs); - while (iterator.hasNext()) { - float[] v = iterator.next(); - vectors.add(v.clone()); + vectors.add(Arrays.copyOf(v, v.length)); } return vectors; } }; + case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16"); }; } @@ -180,7 +141,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); - case BFLOAT16 -> new BFloat16RankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16"); }; } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java index acf920eabb268..5c0ad8680a389 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java @@ -72,7 +72,7 @@ public static class Builder extends FieldMapper.Builder { () -> DenseVectorFieldMapper.ElementType.FLOAT, (n, c, o) -> { DenseVectorFieldMapper.ElementType elementType = namesToElementType.get((String) o); - if (elementType == null) { + if (elementType == null || elementType == ElementType.BFLOAT16) { throw new MapperParsingException( "invalid element_type [" + o + "]; available types are " + namesToElementType.keySet() ); 
diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java index d846a54d0bc83..bd1c06f7c1dd1 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java @@ -351,12 +351,13 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list); } } - case FLOAT, BFLOAT16 -> { + case FLOAT -> { if (queryVector instanceof List) { yield new MaxSimFloatDotProduct(scoreScript, field, (List>) queryVector); } throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName()); } + case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16"); }; } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index 88517a860bb60..33b2f084249b2 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; import 
org.elasticsearch.index.mapper.LuceneDocument; @@ -49,7 +48,6 @@ import java.util.stream.Stream; import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomNormalizedVector; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToBFloat16List; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToList; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -63,12 +61,9 @@ public class RankVectorsFieldMapperTests extends SyntheticVectorsMapperTestCase private final int dims; public RankVectorsFieldMapperTests() { - this.elementType = ElementType.BFLOAT16;// randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); + this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); int baseDims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4; - int randomMultiplier = switch (elementType) { - case FLOAT, BFLOAT16 -> randomIntBetween(1, 64); - case BYTE, BIT -> 1; - }; + int randomMultiplier = ElementType.FLOAT == elementType ? randomIntBetween(1, 64) : 1; this.dims = baseDims * randomMultiplier; } @@ -97,12 +92,11 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I @Override protected Object getSampleValueForDocument() { int numVectors = randomIntBetween(1, 16); - return Stream.generate(() -> switch (elementType) { - case FLOAT -> convertToList(randomNormalizedVector(this.dims)); - case BFLOAT16 -> convertToBFloat16List(randomNormalizedVector(this.dims)); - case BYTE -> convertToList(randomByteArrayOfLength(dims)); - case BIT -> convertToList(randomByteArrayOfLength(this.dims / Byte.SIZE)); - }).limit(numVectors).toList(); + return Stream.generate( + () -> elementType == ElementType.FLOAT + ? convertToList(randomNormalizedVector(this.dims)) + : convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? 
this.dims / Byte.SIZE : dims)) + ).limit(numVectors).toList(); } @Override @@ -424,7 +418,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield vectors; } - case FLOAT, BFLOAT16 -> { + case FLOAT -> { float[][] vectors = new float[numVectors][vectorFieldType.getVectorDimensions()]; for (int i = 0; i < numVectors; i++) { for (int j = 0; j < vectorFieldType.getVectorDimensions(); j++) { @@ -440,6 +434,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield vectors; } + case BFLOAT16 -> throw new AssertionError(); }; } @@ -473,25 +468,21 @@ protected boolean supportsEmptyInputArray() { private static class DenseVectorSyntheticSourceSupport implements SyntheticSourceSupport { private final int dims = between(5, 1000); private final int numVecs = between(1, 16); - private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); + private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); @Override public SyntheticSourceExample example(int maxValues) { Object value = switch (elementType) { case BYTE, BIT -> randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); case FLOAT -> randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); - case BFLOAT16 -> randomList( - numVecs, - numVecs, - () -> randomList(dims, dims, () -> BFloat16.truncateToBFloat16(randomFloat())) - ); + case BFLOAT16 -> throw new AssertionError(); }; return new SyntheticSourceExample(value, value, this::mapping); } private void mapping(XContentBuilder b) throws IOException { b.field("type", "rank_vectors"); - if (elementType != ElementType.FLOAT || randomBoolean()) { + if (elementType == ElementType.BYTE || elementType == ElementType.BIT || randomBoolean()) { b.field("element_type", elementType.toString()); } b.field("dims", elementType == ElementType.BIT ? 
dims * Byte.SIZE : dims); diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java index d8094195a4508..127ad6c7dbe43 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; -import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.RankVectors; @@ -53,36 +52,6 @@ public void testFloatGetVectorValueAndGetMagnitude() throws IOException { } } - public void testBFloat16GetVectorValueAndGetMagnitude() throws IOException { - int dims = 3; - float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; - float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; - - BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); - BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); - RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BFLOAT16, - dims - ); - RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); - for (int i = 0; i < vectors.length; i++) { - field.setNextDocId(i); - assertEquals(vectors[i].length, field.size()); - assertEquals(dims, 
scriptDocValues.dims()); - Iterator iterator = scriptDocValues.getVectorValues(); - float[] magnitudes = scriptDocValues.getMagnitudes(); - assertEquals(expectedMagnitudes[i].length, magnitudes.length); - for (int j = 0; j < vectors[i].length; j++) { - assertTrue(iterator.hasNext()); - assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f); - assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f); - } - } - } - public void testByteGetVectorValueAndGetMagnitude() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -129,36 +98,6 @@ public void testFloatMetadataAndIterator() throws IOException { assertEquals(dv, RankVectors.EMPTY); } - public void testBFloat16MetadataAndIterator() throws IOException { - int dims = 3; - float[][][] vectors = new float[][][] { - fill(new float[3][dims], ElementType.BFLOAT16), - fill(new float[2][dims], ElementType.BFLOAT16) }; - float[][] magnitudes = new float[][] { new float[3], new float[2] }; - BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); - BinaryDocValues magnitudeValues = wrap(magnitudes); - - RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BFLOAT16, - dims - ); - for (int i = 0; i < vectors.length; i++) { - field.setNextDocId(i); - RankVectors dv = field.get(); - assertEquals(vectors[i].length, dv.size()); - assertFalse(dv.isEmpty()); - assertEquals(dims, dv.getDims()); - UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator); - assertEquals("Cannot iterate over single valued rank_vectors field, use get() instead", e.getMessage()); - } - field.setNextDocId(vectors.length); - RankVectors dv = field.get(); - assertEquals(dv, RankVectors.EMPTY); - } - public void testByteMetadataAndIterator() throws IOException { int dims = 3; float[][][] vectors = new float[][][] { fill(new float[3][dims], 
ElementType.BYTE), fill(new float[2][dims], ElementType.BYTE) }; @@ -207,30 +146,6 @@ public void testFloatMissingValues() throws IOException { assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); } - public void testBFloat16MissingValues() throws IOException { - int dims = 3; - float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; - float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; - BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); - BinaryDocValues magnitudeValues = wrap(magnitudes); - RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BFLOAT16, - dims - ); - RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); - - field.setNextDocId(3); - assertEquals(0, field.size()); - Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); - assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); - - e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); - assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); - } - public void testByteMissingValues() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -269,32 +184,6 @@ public void testFloatGetFunctionIsNotAccessible() throws IOException { ); } - public void testBFloat16GetFunctionIsNotAccessible() throws IOException { - int dims = 3; - float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; - float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; - BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); - BinaryDocValues magnitudeValues = wrap(magnitudes); - RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( - docValues, - magnitudeValues, - 
"test", - ElementType.BFLOAT16, - dims - ); - RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); - - field.setNextDocId(0); - Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); - assertThat( - e.getMessage(), - containsString( - "accessing a rank-vectors field's value through 'get' or 'value' is not supported," - + " use 'vectorValues' or 'magnitudes' instead." - ) - ); - } - public void testByteGetFunctionIsNotAccessible() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -417,11 +306,12 @@ public static BytesRef mockEncodeDenseVector(float[][] values, ElementType eleme ByteBuffer byteBuffer = element.createByteBuffer(indexVersion, numBytes * values.length); for (float[] vector : values) { for (float value : vector) { - switch (elementType) { - case FLOAT -> byteBuffer.putFloat(value); - case BFLOAT16 -> byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); - case BYTE, BIT -> byteBuffer.put((byte) value); - default -> throw new IllegalStateException("unknown element_type [" + elementType + "]"); + if (elementType == ElementType.FLOAT) { + byteBuffer.putFloat(value); + } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) { + byteBuffer.put((byte) value); + } else { + throw new IllegalStateException("unknown element_type [" + elementType + "]"); } } } diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index 22236041a305a..9dcf5537ee9d5 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.Strings; import 
org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; -import org.elasticsearch.index.mapper.MapperFeatures; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.index.query.NestedQueryBuilder; @@ -35,7 +34,6 @@ import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; -import org.junit.Before; import org.junit.BeforeClass; import java.io.IOException; @@ -85,17 +83,6 @@ public static Iterable parameters() { return List.of(new Object[] { true }, new Object[] { false }); } - private boolean runTest; - - @Before - public void checkSupport() { - if (CLUSTER_TYPE == ClusterType.OLD) { - runTest = DENSE_MODEL.getServiceSettings().elementType() != DenseVectorFieldMapper.ElementType.BFLOAT16 - || clusterHasFeature(MapperFeatures.HNSW_BFLOAT16_ON_DISK_RESCORING); - } - assumeTrue("Old cluster needs to support bfloat16", runTest); - } - public void testSemanticTextOperations() throws Exception { switch (CLUSTER_TYPE) { case OLD -> createAndPopulateIndex(); From 99395e55543900c1acb340bd48c8430f413c7e17 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 24 Oct 2025 16:03:20 +0100 Subject: [PATCH 28/46] Need some basic support in downstream classes --- .../mapper/SemanticInferenceMetadataFieldsRecoveryTests.java | 3 +-- .../xpack/inference/mapper/SemanticTextFieldTests.java | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 1c677fb9532cd..9bc1736a85c7b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -269,9 +269,8 @@ private static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs); - case BFLOAT16 -> throw new AssertionError(); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 14963f9f82f3b..7fbdab369707a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -273,9 +273,8 @@ public static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); - case BFLOAT16 -> throw new AssertionError(); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); 
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); From 297db11d4a87b63b983861ad57edcc56c22b13ad Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 24 Oct 2025 16:28:20 +0100 Subject: [PATCH 29/46] Test updates --- .../search.vectors/40_knn_search_bfloat16.yml | 707 ++++++++++++++++++ .../46_knn_search_bbq_ivf_bfloat16.yml | 629 ++++++++++++++++ .../mapper/SemanticTextFieldTests.java | 2 +- 3 files changed, 1337 insertions(+), 1 deletion(-) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml new file mode 100644 index 0000000000000..3703fa28c382e --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml @@ -0,0 +1,707 @@ +setup: + - requires: + cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] + reason: 'bfloat16 needs to be supported' + - do: + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + dims: 5 + index: true + similarity: l2_norm + index_options: + type: hnsw + m: 16 + ef_construction: 200 + element_type: bfloat16 + another_vector: + type: dense_vector + dims: 5 + index: true + similarity: l2_norm + index_options: + type: hnsw + m: 16 + ef_construction: 200 + element_type: bfloat16 + - do: + index: + index: test + id: "1" + body: + name: cow.jpg + vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ] + another_vector: [ 130.0, 115.0, -1.02, 15.555, -100.0 ] + + - do: + index: + index: test + id: "2" + body: + name: moose.jpg + vector: [ -0.5, 100.0, -13, 
14.8, -156.0 ] + another_vector: [ -0.5, 50.0, -1, 1, 120 ] + + - do: + index: + index: test + id: "3" + body: + name: rabbit.jpg + vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ] + another_vector: [ -0.5, 11.0, 0, 12, 111.0 ] + + - do: + indices.refresh: { } + +--- +"kNN search only": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } + + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1.fields.name.0: "rabbit.jpg" } +--- +"kNN multi-field search only": + - requires: + cluster_features: "gte_v8.7.0" + reason: 'multi-field kNN search added to search endpoint in 8.7' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + - { field: vector, query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ], k: 2, num_candidates: 3 } + - { field: another_vector, query_vector: [ -0.5, 11.0, 0, 12, 111.0 ], k: 2, num_candidates: 3 } + + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.fields.name.0: "moose.jpg" } +--- +"kNN search plus query": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + query: + term: + name: cow.jpg + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0.fields.name.0: "cow.jpg" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.fields.name.0: "moose.jpg" } + + - match: { hits.hits.2._id: "3" } + - match: { hits.hits.2.fields.name.0: "rabbit.jpg" } +--- +"kNN multi-field search with query": + - requires: + cluster_features: 
"gte_v8.7.0" + reason: 'multi-field kNN search added to search endpoint in 8.7' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + - { field: vector, query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ], k: 2, num_candidates: 3 } + - { field: another_vector, query_vector: [ -0.5, 11.0, 0, 12, 111.0 ], k: 2, num_candidates: 3 } + query: + term: + name: cow.jpg + + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1.fields.name.0: "cow.jpg" } + + - match: { hits.hits.2._id: "2" } + - match: { hits.hits.2.fields.name.0: "moose.jpg" } +--- +"kNN search with filter": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + filter: + term: + name: "rabbit.jpg" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + filter: + - term: + name: "rabbit.jpg" + - term: + _id: 2 + + - match: { hits.total.value: 0 } + +--- +"kNN search with explicit search_type": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + catch: bad_request + search: + index: test + search_type: query_then_fetch + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { error.root_cause.0.type: "illegal_argument_exception" } + - match: { error.root_cause.0.reason: "cannot set [search_type] when using [knn] search, since the search type is determined automatically" } + +--- +"kNN search in _knn_search endpoint": + - skip: + 
features: [ "allowed_warnings", "headers" ] + - do: + headers: + Content-Type: "application/vnd.elasticsearch+json;compatible-with=8" + Accept: "application/vnd.elasticsearch+json;compatible-with=8" + allowed_warnings: + - "The kNN search API has been replaced by the `knn` option in the search API." + knn_search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } + + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1.fields.name.0: "rabbit.jpg" } + +--- +"kNN search with filter in _knn_search endpoint": + - requires: + cluster_features: "gte_v8.2.0" + reason: 'kNN with filtering added in 8.2' + test_runner_features: [ "allowed_warnings", "headers" ] + - do: + headers: + Content-Type: "application/vnd.elasticsearch+json;compatible-with=8" + Accept: "application/vnd.elasticsearch+json;compatible-with=8" + allowed_warnings: + - "The kNN search API has been replaced by the `knn` option in the search API." + knn_search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + filter: + term: + name: "rabbit.jpg" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - do: + headers: + Content-Type: "application/vnd.elasticsearch+json;compatible-with=8" + Accept: "application/vnd.elasticsearch+json;compatible-with=8" + allowed_warnings: + - "The kNN search API has been replaced by the `knn` option in the search API." 
+ knn_search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + filter: + - term: + name: "rabbit.jpg" + - term: + _id: 2 + + - match: { hits.total.value: 0 } + +--- +"Test nonexistent field is match none": + - requires: + cluster_features: "gte_v8.16.0" + reason: 'non-existent field handling improved in 8.16' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nonexistent + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - length: { hits.hits: 0 } + + - do: + indices.create: + index: test_nonexistent + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: float + dims: 5 + index: true + similarity: l2_norm + settings: + index.query.parse.allow_unmapped_fields: false + + - do: + catch: bad_request + search: + index: test_nonexistent + body: + fields: [ "name" ] + knn: + field: nonexistent + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { error.root_cause.0.type: "query_shard_exception" } + - match: { error.root_cause.0.reason: "No field mapping can be found for the field with name [nonexistent]" } + +--- +"KNN Vector similarity search only": + - requires: + cluster_features: "gte_v8.8.0" + reason: 'kNN similarity added in 8.8' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 11 + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + + - length: { hits.hits: 1 } + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } +--- +"Vector similarity with filter only": + - requires: + cluster_features: "gte_v8.8.0" + reason: 'kNN similarity added in 8.8' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 11 + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + 
filter: { "term": { "name": "moose.jpg" } } + + - length: { hits.hits: 1 } + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 110 + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + filter: { "term": { "name": "cow.jpg" } } + + - length: { hits.hits: 0 } +--- +"Knn search with mip": + - requires: + cluster_features: "gte_v8.11.0" + reason: 'mip similarity added in 8.11' + test_runner_features: "close_to" + + - do: + indices.create: + index: mip + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + dims: 5 + index: true + similarity: max_inner_product + index_options: + type: hnsw + m: 16 + ef_construction: 200 + + - do: + index: + index: mip + id: "1" + body: + name: cow.jpg + vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ] + + - do: + index: + index: mip + id: "2" + body: + name: moose.jpg + vector: [ -0.5, 100.0, -13, 14.8, -156.0 ] + + - do: + index: + index: mip + id: "3" + body: + name: rabbit.jpg + vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ] + + - do: + indices.refresh: { } + + - do: + search: + index: mip + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + + + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 58694.902, error: 0.01 } } + - match: { hits.hits.1._id: "3" } + - close_to: { hits.hits.1._score: { value: 34702.79, error: 0.01 } } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 33686.29, error: 0.01 } } + + - do: + search: + index: mip + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + filter: { "term": { "name": "moose.jpg" } } + + + + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "2" } 
+ - close_to: { hits.hits.0._score: { value: 33686.29, error: 0.01 } } +--- +"Knn search with _name": + - requires: + cluster_features: "gte_v8.15.0" + reason: 'support for _name in knn was added in 8.15' + test_runner_features: "close_to" + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 3 + _name: "my_knn_query" + query: + term: + name: + term: cow.jpg + _name: "my_query" + + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0.fields.name.0: "cow.jpg" } + - match: { hits.hits.0.matched_queries.0: "my_knn_query" } + - match: { hits.hits.0.matched_queries.1: "my_query" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.fields.name.0: "moose.jpg" } + - match: { hits.hits.1.matched_queries.0: "my_knn_query" } + + - match: { hits.hits.2._id: "3" } + - match: { hits.hits.2.fields.name.0: "rabbit.jpg" } + - match: { hits.hits.2.matched_queries.0: "my_knn_query" } + +--- +"kNN search on empty index should return 0 results and not an error": + - requires: + cluster_features: "gte_v8.15.1" + reason: 'Error fixed in 8.15.1' + - do: + indices.create: + index: test_empty + body: + mappings: + properties: + vector: + type: dense_vector + - do: + search: + index: test_empty + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { hits.total.value: 0 } +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + 
knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } + +--- +"Dimensions are dynamically set": + - do: + indices.create: + index: test_index + body: + mappings: + properties: + embedding: + type: dense_vector + + - do: + index: + index: test_index + id: "0" + refresh: true + body: + embedding: [ 0.5, 111.3, -13.0, 14.8, -156.0 ] + + # wait and ensure that the mapping update is replicated + - do: + cluster.health: + wait_for_events: languid + + - do: + indices.get_mapping: + index: test_index + + - match: { test_index.mappings.properties.embedding.type: dense_vector } + - match: { test_index.mappings.properties.embedding.dims: 5 } + + - do: + catch: bad_request + index: + index: test_index + id: "0" + body: + embedding: [ 0.5, 111.3 ] + +--- +"Updating dim to null is not allowed": + - requires: + cluster_features: "mapper.npe_on_dims_update_fix" + reason: "dims update fix" + - do: + indices.create: + index: test_index + + - do: + indices.put_mapping: + index: test_index + body: + properties: + embedding: + type: dense_vector + dims: 4 + - do: + catch: bad_request + indices.put_mapping: + index: test_index + body: + properties: + embedding: + type: 
dense_vector + + +--- +"Searching with no data dimensions specified": + - requires: + cluster_features: "search.vectors.no_dimensions_bugfix" + reason: "Search with no dimensions bugfix" + + - do: + indices.create: + index: empty-test + body: + mappings: + properties: + vector: + type: dense_vector + index: true + + - do: + search: + index: empty-test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + similarity: 0.1 + + - match: { hits.total.value: 0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml new file mode 100644 index 0000000000000..f90ae0575afe3 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml @@ -0,0 +1,629 @@ +setup: + - requires: + cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] + reason: 'bfloat16 needs to be supported' + - skip: + features: "headers" + - do: + indices.create: + index: bbq_disk + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + element_type: bfloat16 + + - do: + index: + index: bbq_disk + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, 
-0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_disk + + - do: + index: + index: bbq_disk + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_disk + + - do: + index: + index: bbq_disk + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_disk + + - do: + indices.forcemerge: + index: bbq_disk + max_num_segments: 1 + + - do: + indices.refresh: { } +--- +"Test knn search": + - do: + search: + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , 
-0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Test knn search with visit_percentage": + - do: + search: + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + visit_percentage: 1.0 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { 
hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad quantization parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_ivf + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + element_type: byte + index: true + index_options: + type: bbq_disk + + - do: + catch: bad_request + indices.create: + index: bad_bbq_ivf + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: false + index_options: + type: bbq_disk +--- +"Test index configured rescore vector": + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: 
max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_ivf + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, 
-0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_ivf + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test index configured rescore vector with on-disk rescore": + - requires: + cluster_features: [ "mapper.vectors.diskbbq_on_disk_rescoring" ] + reason: Needs on_disk_rescoring feature for DiskBBQ + - skip: + features: "headers" + - do: + indices.create: + index: 
bbq_on_disk_rescore_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + on_disk_rescore: true + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_on_disk_rescore_ivf + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + 
rest_total_hits_as_int: true + index: bbq_on_disk_rescore_ivf + body: + knn: + field: vector + query_vector: [ 0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158 ] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_on_disk_rescore_ivf + body: + query: + script_score: + query: { match_all: { } } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [ 0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158 ] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector updateable and settable to 0": + - do: + indices.create: + index: bbq_rescore_0_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + indices.create: + index: bbq_rescore_update_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + index_options: + type: bbq_disk + rescore_vector: + oversample: 1 + + - do: + indices.put_mapping: + index: bbq_rescore_update_ivf + body: + properties: + vector: + type: dense_vector + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + indices.get_mapping: + index: bbq_rescore_update_ivf + + - match: { .bbq_rescore_update_ivf.mappings.properties.vector.index_options.rescore_vector.oversample: 0 } +--- +"Test index configured rescore vector score consistency": + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_zero_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + 
dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + bulk: + index: bbq_rescore_zero_ivf + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 
0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 2 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: override_score0 } + - set: { hits.hits.1._score: override_score1 } + - set: { hits.hits.2._score: override_score2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_ivf + body: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 2 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + 
-0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: default_rescore0 } + - set: { hits.hits.1._score: default_rescore1 } + - set: { hits.hits.2._score: default_rescore2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_ivf + body: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $override_score0 } + - match: { hits.hits.0._score: $default_rescore0 } + - match: { hits.hits.1._score: $override_score1 } + - match: { hits.hits.1._score: $default_rescore1 } + - match: { hits.hits.2._score: $override_score2 } + - match: { hits.hits.2._score: $default_rescore2 } + +--- +"default oversample value": + - do: + indices.get_mapping: + index: bbq_disk + + - match: { bbq_disk.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 7fbdab369707a..5eb64696b5917 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -223,7 +223,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = 
DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); - assert elementType == DenseVectorFieldMapper.ElementType.FLOAT; + assert elementType == DenseVectorFieldMapper.ElementType.FLOAT || elementType == DenseVectorFieldMapper.ElementType.BFLOAT16; List chunks = new ArrayList<>(); for (String input : inputs) { From 5965de33939a48579c2dffca87df0da5865c914d Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 24 Oct 2025 17:25:00 +0100 Subject: [PATCH 30/46] Non-quantized HNSW doesn't need direct IO Various test updates --- .../search.vectors/40_knn_search_bfloat16.yml | 4 +- .../46_knn_search_bbq_ivf_bfloat16.yml | 2 +- .../elasticsearch/index/store/DirectIOIT.java | 2 +- .../vectors/es93/ES93HnswVectorsFormat.java | 11 +- .../vectors/DenseVectorFieldMapper.java | 33 +--- .../BFloat16RankVectorsDocValuesField.java | 159 ------------------ .../ES93HnswBFloat16VectorsFormatTests.java | 13 +- .../es93/ES93HnswBitVectorsFormatTests.java | 4 +- .../es93/ES93HnswVectorsFormatTests.java | 13 +- .../vectors/DenseVectorFieldTypeTests.java | 6 +- .../vectors/RescoreKnnVectorQueryTests.java | 15 +- 11 files changed, 36 insertions(+), 226 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml index 3703fa28c382e..aa80764f14136 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml @@ -12,6 +12,7 @@ setup: type: keyword vector: type: dense_vector + element_type: bfloat16 dims: 5 index: true similarity: l2_norm @@ -19,9 +20,9 @@ setup: type: hnsw m: 16 ef_construction: 200 
- element_type: bfloat16 another_vector: type: dense_vector + element_type: bfloat16 dims: 5 index: true similarity: l2_norm @@ -29,7 +30,6 @@ setup: type: hnsw m: 16 ef_construction: 200 - element_type: bfloat16 - do: index: index: test diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml index f90ae0575afe3..aa59bceb00598 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml @@ -15,12 +15,12 @@ setup: properties: vector: type: dense_vector + element_type: bfloat16 dims: 64 index: true similarity: max_inner_product index_options: type: bbq_disk - element_type: bfloat16 - do: index: diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index b0bac4686eed0..c8ceb200ea0b3 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -73,7 +73,7 @@ protected boolean useDirectIO(String name, IOContext context, OptionalLong fileL @ParametersFactory public static Iterable parameters() { - return List.of(new Object[] { "hnsw" }, new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" }); + return List.of(new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" }); } public DirectIOIT(String type) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index bd191f2dfed64..ab289e9a0f858 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -32,26 +32,25 @@ public ES93HnswVectorsFormat() { flatVectorsFormat = new ES93GenericFlatVectorsFormat(); } - public ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType) { super(NAME); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } - public ES93HnswVectorsFormat(int maxConn, int beamWidth, ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93HnswVectorsFormat(int maxConn, int beamWidth, ES93GenericFlatVectorsFormat.ElementType elementType) { super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } public ES93HnswVectorsFormat( int maxConn, int beamWidth, ES93GenericFlatVectorsFormat.ElementType elementType, - boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 6bac155fa547f..0ff09e85c2b94 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1342,14 +1342,12 @@ public enum 
VectorIndexType { public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); - boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new HnswIndexOptions(m, efConstruction, onDiskRescore); + return new HnswIndexOptions(m, efConstruction); } @Override @@ -2007,31 +2005,19 @@ public boolean updatableTo(DenseVectorIndexOptions update) { public static class HnswIndexOptions extends DenseVectorIndexOptions { private final int m; private final int efConstruction; - private final boolean onDiskRescore; - HnswIndexOptions(int m, int efConstruction, boolean onDiskRescore) { + HnswIndexOptions(int m, int efConstruction) { super(VectorIndexType.HNSW); this.m = m; this.efConstruction = efConstruction; - this.onDiskRescore = onDiskRescore; } @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { return switch (elementType) { - case BIT -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.BIT, onDiskRescore); - case BYTE, FLOAT -> new ES93HnswVectorsFormat( - m, - efConstruction, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, - onDiskRescore - ); - case BFLOAT16 -> new ES93HnswVectorsFormat( - m, - efConstruction, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, - onDiskRescore - ); + case BIT -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.BIT); + case BYTE, FLOAT -> new ES93HnswVectorsFormat(m, efConstruction, 
ES93GenericFlatVectorsFormat.ElementType.STANDARD); + case BFLOAT16 -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16); }; } @@ -2056,9 +2042,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); - if (onDiskRescore) { - builder.field("on_disk_rescore", true); - } builder.endObject(); return builder; } @@ -2068,12 +2051,12 @@ public boolean doEquals(DenseVectorIndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; HnswIndexOptions that = (HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && onDiskRescore == that.onDiskRescore; + return m == that.m && efConstruction == that.efConstruction; } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, onDiskRescore); + return Objects.hash(m, efConstruction); } @Override @@ -2091,7 +2074,7 @@ public int efConstruction() { @Override public String toString() { - return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + ", on_disk_rescore=" + onDiskRescore + "}"; + return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; } } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java deleted file mode 100644 index 4759da76bc75b..0000000000000 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.script.field.vectors; - -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.index.codec.vectors.BFloat16; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.ShortBuffer; -import java.util.Iterator; - -public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField { - - private final BinaryDocValues input; - private final BinaryDocValues magnitudes; - private boolean decoded; - private final int dims; - private BytesRef value; - private BytesRef magnitudesValue; - private BFloat16VectorIterator vectorValues; - private int numVectors; - private final float[] buffer; - - public BFloat16RankVectorsDocValuesField( - BinaryDocValues input, - BinaryDocValues magnitudes, - String name, - ElementType elementType, - int dims - ) { - super(name, elementType); - this.input = input; - this.magnitudes = magnitudes; - this.dims = dims; - this.buffer = new float[dims]; - } - - @Override - public void setNextDocId(int docId) throws IOException { - decoded = false; - if (input.advanceExact(docId)) { - boolean magnitudesFound = magnitudes.advanceExact(docId); - assert magnitudesFound; - - value = input.binaryValue(); - assert value.length % (BFloat16.BYTES * dims) == 0; - numVectors = value.length / (BFloat16.BYTES * dims); - magnitudesValue = magnitudes.binaryValue(); - assert magnitudesValue.length == (Float.BYTES * 
numVectors); - } else { - value = null; - magnitudesValue = null; - numVectors = 0; - } - } - - @Override - public RankVectorsScriptDocValues toScriptDocValues() { - return new RankVectorsScriptDocValues(this, dims); - } - - @Override - public boolean isEmpty() { - return value == null; - } - - @Override - public RankVectors get() { - if (isEmpty()) { - return RankVectors.EMPTY; - } - decodeVectorIfNecessary(); - return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); - } - - @Override - public RankVectors get(RankVectors defaultValue) { - if (isEmpty()) { - return defaultValue; - } - decodeVectorIfNecessary(); - return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); - } - - @Override - public RankVectors getInternal() { - return get(null); - } - - @Override - public int size() { - return value == null ? 0 : value.length / (BFloat16.BYTES * dims); - } - - private void decodeVectorIfNecessary() { - if (decoded == false && value != null) { - vectorValues = new BFloat16VectorIterator(value, buffer, numVectors); - decoded = true; - } - } - - public static class BFloat16VectorIterator implements VectorIterator { - private final float[] buffer; - private final ShortBuffer vectorValues; - private final BytesRef vectorValueBytesRef; - private final int size; - private int idx = 0; - - public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) { - assert vectorValues.length == (buffer.length * BFloat16.BYTES * size); - this.vectorValueBytesRef = vectorValues; - this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) - .order(ByteOrder.LITTLE_ENDIAN) - .asShortBuffer(); - this.size = size; - this.buffer = buffer; - } - - @Override - public boolean hasNext() { - return idx < size; - } - - @Override - public float[] next() { - if (hasNext() == false) { - throw new IllegalArgumentException("No more elements in the iterator"); - } - for (int i = 0; i < buffer.length; i++) { 
- buffer[i] = BFloat16.bFloat16ToFloat(vectorValues.get()); - } - idx++; - return buffer; - } - - @Override - public Iterator copy() { - return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size); - } - - @Override - public void reset() { - idx = 0; - vectorValues.rewind(); - } - } -} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java index e6c7ab2f256d4..46956a279fc26 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -30,24 +30,17 @@ public class ES93HnswBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsF @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { - return new ES93HnswVectorsFormat( - maxConn, - beamWidth, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, - random().nextBoolean(), - numMergeWorkers, - service - ); + return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, numMergeWorkers, service); } public void testToString() { diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java index f5f15bb1f06df..2d25d04289cab 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java @@ -36,9 +36,7 @@ public class ES93HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCa @Override protected Codec getCodec() { - return TestUtil.alwaysKnnVectorsFormat( - new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BIT, random().nextBoolean()) - ); + return TestUtil.alwaysKnnVectorsFormat(new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BIT)); } @Before diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java index 5fa507c23d756..3fe512038cad5 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -29,24 +29,17 @@ public class ES93HnswVectorsFormatTests extends BaseHnswVectorsFormatTestCase { @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.STANDARD); } @Override 
protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { - return new ES93HnswVectorsFormat( - maxConn, - beamWidth, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, - random().nextBoolean(), - numMergeWorkers, - service - ); + return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.STANDARD, numMergeWorkers, service); } public void testToString() { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index fc5bb65789cc3..583472a6ff076 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -66,7 +66,7 @@ private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { private static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsNonQuantized() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomBoolean()), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), new DenseVectorFieldMapper.FlatIndexOptions() ); } @@ -87,7 +87,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomFlatIndexOpti public static DenseVectorFieldMapper.DenseVectorIndexOptions randomGpuSupportedIndexOptions() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 3199), randomBoolean()), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 3199)), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 3199), @@ -109,7 +109,7 @@ public static DenseVectorFieldMapper.VectorSimilarity 
randomGPUSupportedSimilari public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsAll() { List options = new ArrayList<>( Arrays.asList( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomBoolean()), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 7d42570498f5f..337c3875deb42 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -43,8 +43,7 @@ import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; @@ -291,11 +290,15 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), - new ES818BinaryQuantizedVectorsFormat(), - new ES818HnswBinaryQuantizedVectorsFormat(), + new 
ES93BinaryQuantizedVectorsFormat( + randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), + randomBoolean() + ), + new ES93HnswBinaryQuantizedVectorsFormat( + randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), + randomBoolean() + ), new ES93HnswBinaryQuantizedVectorsFormat( - DEFAULT_VECTORS_PER_CLUSTER, - DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), randomBoolean() ), From 2bd4423ebe07cedae5256869268023910011d5b7 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 30 Oct 2025 10:10:49 +0000 Subject: [PATCH 31/46] Use DenseVectorFieldMapper ElementType --- .../ES93BinaryQuantizedVectorsFormat.java | 5 +-- .../es93/ES93GenericFlatVectorsFormat.java | 19 +++++------- .../ES93HnswBinaryQuantizedVectorsFormat.java | 7 +++-- .../vectors/es93/ES93HnswVectorsFormat.java | 7 +++-- .../vectors/DenseVectorFieldMapper.java | 31 +++---------------- ...ryQuantizedBFloat16VectorsFormatTests.java | 5 +-- ...ES93BinaryQuantizedVectorsFormatTests.java | 5 +-- .../ES93HnswBFloat16VectorsFormatTests.java | 7 +++-- ...ryQuantizedBFloat16VectorsFormatTests.java | 7 +++-- ...HnswBinaryQuantizedVectorsFormatTests.java | 7 +++-- .../es93/ES93HnswBitVectorsFormatTests.java | 3 +- .../es93/ES93HnswVectorsFormatTests.java | 7 +++-- .../vectors/RescoreKnnVectorQueryTests.java | 7 ++--- 13 files changed, 50 insertions(+), 67 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index ed224c82a5aaa..290b010fef78c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -30,6 +30,7 @@ import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; @@ -97,10 +98,10 @@ public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat private final ES93GenericFlatVectorsFormat rawFormat; public ES93BinaryQuantizedVectorsFormat() { - this(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false); + this(DenseVectorFieldMapper.ElementType.FLOAT, false); } - public ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); rawFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 191da9fd24778..1f7ddaf764d9c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -17,18 +17,13 @@ import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Map; public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { - 
public enum ElementType { - STANDARD, - BIT, // only supports byte[] - BFLOAT16 // only supports float[] - } - static final String NAME = "ES93GenericFlatVectorsFormat"; static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta"; @@ -43,7 +38,7 @@ public enum ElementType { VERSION_CURRENT ); - private static final DirectIOCapableFlatVectorsFormat standardVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + private static final DirectIOCapableFlatVectorsFormat defaultVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); private static final DirectIOCapableFlatVectorsFormat bitVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( @@ -62,8 +57,8 @@ public String getName() { private static final Map supportedFormats = Map.of( bitVectorFormat.getName(), bitVectorFormat, - standardVectorFormat.getName(), - standardVectorFormat, + defaultVectorFormat.getName(), + defaultVectorFormat, bfloat16VectorFormat.getName(), bfloat16VectorFormat ); @@ -72,13 +67,13 @@ public String getName() { private final boolean useDirectIO; public ES93GenericFlatVectorsFormat() { - this(ElementType.STANDARD, false); + this(DenseVectorFieldMapper.ElementType.FLOAT, false); } - public ES93GenericFlatVectorsFormat(ElementType elementType, boolean useDirectIO) { + public ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); writeFormat = switch (elementType) { - case STANDARD -> standardVectorFormat; + case FLOAT, BYTE -> defaultVectorFormat; case BIT -> bitVectorFormat; case BFLOAT16 -> bfloat16VectorFormat; }; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index 78f356f8762da..15dce38cb742c 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -49,7 +50,7 @@ public ES93HnswBinaryQuantizedVectorsFormat() { * * @param useDirectIO whether to use direct IO when reading raw vectors */ - public ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } @@ -64,7 +65,7 @@ public ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.Element public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, - ES93GenericFlatVectorsFormat.ElementType elementType, + DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO ) { super(NAME, maxConn, beamWidth); @@ -85,7 +86,7 @@ public ES93HnswBinaryQuantizedVectorsFormat( public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, - ES93GenericFlatVectorsFormat.ElementType elementType, + DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index ab289e9a0f858..bfa14632b4a4a 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -17,6 +17,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -32,12 +33,12 @@ public ES93HnswVectorsFormat() { flatVectorsFormat = new ES93GenericFlatVectorsFormat(); } - public ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType) { + public ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType elementType) { super(NAME); flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } - public ES93HnswVectorsFormat(int maxConn, int beamWidth, ES93GenericFlatVectorsFormat.ElementType elementType) { + public ES93HnswVectorsFormat(int maxConn, int beamWidth, DenseVectorFieldMapper.ElementType elementType) { super(NAME, maxConn, beamWidth); flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } @@ -45,7 +46,7 @@ public ES93HnswVectorsFormat(int maxConn, int beamWidth, ES93GenericFlatVectorsF public ES93HnswVectorsFormat( int maxConn, int beamWidth, - ES93GenericFlatVectorsFormat.ElementType elementType, + DenseVectorFieldMapper.ElementType elementType, int numMergeWorkers, ExecutorService mergeExec ) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0ff09e85c2b94..d791d77794e77 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -55,7 +55,6 @@ import 
org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; @@ -2014,11 +2013,7 @@ public static class HnswIndexOptions extends DenseVectorIndexOptions { @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return switch (elementType) { - case BIT -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.BIT); - case BYTE, FLOAT -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.STANDARD); - case BFLOAT16 -> new ES93HnswVectorsFormat(m, efConstruction, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16); - }; + return new ES93HnswVectorsFormat(m, efConstruction, elementType); } @Override @@ -2092,21 +2087,8 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return switch (elementType) { - case FLOAT -> new ES93HnswBinaryQuantizedVectorsFormat( - m, - efConstruction, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, - onDiskRescore - ); - case BFLOAT16 -> new ES93HnswBinaryQuantizedVectorsFormat( - m, - efConstruction, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, - onDiskRescore - ); - case BYTE, BIT -> throw new AssertionError(); - }; + assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; + return new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, elementType, onDiskRescore); } @Override @@ -2171,11 +2153,8 @@ static class BBQFlatIndexOptions extends 
QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return switch (elementType) { - case FLOAT -> new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false); - case BFLOAT16 -> new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, false); - case BYTE, BIT -> throw new AssertionError(); - }; + assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; + return new ES93BinaryQuantizedVectorsFormat(elementType, false); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java index 6cb1f7e61af5f..fb287e39b37d1 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseBFloat16KnnVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -75,7 +76,7 @@ public class ES93BinaryQuantizedBFloat16VectorsFormatTests extends BaseBFloat16K @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean()); super.setUp(); } @@ -196,7 +197,7 @@ public void testToString() { var defaultScorer = expected.replaceAll("\\{}", "DefaultFlatVectorScorer"); var 
memSegScorer = expected.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); - KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, false); + KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, false); assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index b9d53944ac654..39439bac02a8c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -55,6 +55,7 @@ import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -84,7 +85,7 @@ public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean()); super.setUp(); } @@ -201,7 +202,7 @@ public void testToString() { var defaultScorer = expected.replaceAll("\\{}", "DefaultFlatVectorScorer"); var memSegScorer = expected.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); - KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false); + KnnVectorsFormat format = new 
ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false); assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java index 46956a279fc26..0220ba831e75d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -13,6 +13,7 @@ import org.apache.lucene.store.Directory; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -30,17 +31,17 @@ public class ES93HnswBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsF @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16); + return new ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.BFLOAT16); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, numMergeWorkers, service); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.BFLOAT16, numMergeWorkers, service); } public void testToString() { diff 
--git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java index 49dd1ba7c64e4..8d09eafb81ca2 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -16,6 +16,7 @@ import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -34,7 +35,7 @@ public class ES93HnswBinaryQuantizedBFloat16VectorsFormatTests extends BaseHnswB @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + return new ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean()); } @Override @@ -42,7 +43,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, + DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean() ); } @@ -52,7 +53,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, + DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean(), numMergeWorkers, service diff --git 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 8afc974cfca7d..75235d894bb74 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -25,6 +25,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -43,7 +44,7 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseHnswVectorsFo @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + return new ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean()); } @Override @@ -51,7 +52,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean() ); } @@ -61,7 +62,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean(), numMergeWorkers, service diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java index 2d25d04289cab..b54db35d77273 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.index.codec.vectors.BaseKnnBitVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.Before; import java.io.IOException; @@ -36,7 +37,7 @@ public class ES93HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCa @Override protected Codec getCodec() { - return TestUtil.alwaysKnnVectorsFormat(new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BIT)); + return TestUtil.alwaysKnnVectorsFormat(new ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType.BIT)); } @Before diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java index 3fe512038cad5..84057c7709063 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.store.Directory; import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -29,17 +30,17 @@ public class ES93HnswVectorsFormatTests extends BaseHnswVectorsFormatTestCase { @Override protected KnnVectorsFormat createFormat() { - return new 
ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD); + return new ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.STANDARD); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.FLOAT); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.STANDARD, numMergeWorkers, service); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.FLOAT, numMergeWorkers, service); } public void testToString() { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 337c3875deb42..9d510e16e6f32 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -44,7 +44,6 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -291,15 +290,15 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th randomBoolean() ), new 
ES93BinaryQuantizedVectorsFormat( - randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), new ES93HnswBinaryQuantizedVectorsFormat( - randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), new ES93HnswBinaryQuantizedVectorsFormat( - randomFrom(ES93GenericFlatVectorsFormat.ElementType.STANDARD, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16), + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), new ES813Int8FlatVectorFormat(), From a87ffd7de5ce2ff12e88c9af680f60c7901f84c4 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 30 Oct 2025 11:39:32 +0000 Subject: [PATCH 32/46] Update docs/changelog/135940.yaml --- docs/changelog/135940.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/135940.yaml diff --git a/docs/changelog/135940.yaml b/docs/changelog/135940.yaml new file mode 100644 index 0000000000000..d754c98f320f5 --- /dev/null +++ b/docs/changelog/135940.yaml @@ -0,0 +1,5 @@ +pr: 135940 +summary: Enable directIO and bfloat16 for bbq and unquantized vector field types +area: Vector Search +type: feature +issues: [] From e871fbc2702e117ddb01f9e11a49d83325bbad2c Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 30 Oct 2025 12:03:13 +0000 Subject: [PATCH 33/46] Some updates --- .../main/java/org/elasticsearch/index/IndexVersions.java | 1 + .../codec/vectors/es93/ES93GenericFlatVectorsFormat.java | 4 ++-- .../search/vectors/RescoreKnnVectorQueryTests.java | 4 ---- .../org/elasticsearch/xpack/inference/model/TestModel.java | 7 ++++++- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git 
a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index e63b655e2ce8d..73b891ded6f34 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -192,6 +192,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion REENABLED_TIMESTAMP_DOC_VALUES_SPARSE_INDEX = def(9_042_0_00, Version.LUCENE_10_3_1); public static final IndexVersion SKIPPERS_ENABLED_BY_DEFAULT = def(9_043_0_00, Version.LUCENE_10_3_1); + public static final IndexVersion BFLOAT16_HNSW_SUPPORT = def(9_044_0_00, Version.LUCENE_10_3_1); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 1f7ddaf764d9c..0eedc3f16cdfb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -55,10 +55,10 @@ public String getName() { ); private static final Map supportedFormats = Map.of( - bitVectorFormat.getName(), - bitVectorFormat, defaultVectorFormat.getName(), defaultVectorFormat, + bitVectorFormat.getName(), + bitVectorFormat, bfloat16VectorFormat.getName(), bfloat16VectorFormat ); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 9d510e16e6f32..38b3631862690 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -297,10 +297,6 @@ private static void addRandomDocuments(int 
numDocs, Directory d, int numDims) th randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), randomBoolean() ), - new ES93HnswBinaryQuantizedVectorsFormat( - randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), - randomBoolean() - ), new ES813Int8FlatVectorFormat(), new ES813Int8FlatVectorFormat(), new ES814HnswScalarQuantizedVectorsFormat() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index f99e8ce562b42..5eb6d41be1bd0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -54,7 +54,12 @@ public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { if (taskType == TaskType.TEXT_EMBEDDING) { - var elementType = randomFrom(DenseVectorFieldMapper.ElementType.values()); + // TODO: bfloat16 + var elementType = randomFrom( + DenseVectorFieldMapper.ElementType.FLOAT, + DenseVectorFieldMapper.ElementType.BYTE, + DenseVectorFieldMapper.ElementType.BIT + ); var dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions); List supportedSimilarities = new ArrayList<>( From 4f2f6733265bd9e3acfef2f057ae451c85a76a44 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 30 Oct 2025 15:39:17 +0000 Subject: [PATCH 34/46] Provide some more implementations --- .../mapper/vectors/VectorDVLeafFieldData.java | 206 ++++++++++-------- ...loat16BinaryDenseVectorDocValuesField.java | 38 ++++ .../BinaryDenseVectorDocValuesField.java | 6 +- 3 files changed, 160 insertions(+), 90 deletions(-) create mode 100644 
server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index f0e61cb38b4bc..f8729ac0661ec 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -17,11 +17,13 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.fielddata.FormattedDocValues; import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.vectors.BFloat16BinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField; @@ -31,6 +33,9 @@ import org.elasticsearch.search.DocValueFormat; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; import java.util.Arrays; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -69,6 +74,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { if (indexed) { return switch (elementType) { case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); + // bfloat16 is hidden by the FloatVectorValues implementation case FLOAT, BFLOAT16 -> new 
KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); }; @@ -76,7 +82,8 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { BinaryDocValues values = DocValues.getBinary(reader, field); return switch (elementType) { case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims); - case FLOAT, BFLOAT16 -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case BFLOAT16 -> new BFloat16BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims); }; } @@ -85,105 +92,126 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { } } - @Override - public FormattedDocValues getFormattedValues(DocValueFormat format) { - int dims = elementType == ElementType.BIT ? this.dims / Byte.SIZE : this.dims; - return switch (elementType) { - case BYTE, BIT -> new FormattedDocValues() { - private byte[] vector = new byte[dims]; - private ByteVectorValues byteVectorValues; // use when indexed - private KnnVectorValues.DocIndexIterator iterator; // use when indexed - private BinaryDocValues binary; // use when not indexed - { - try { - if (indexed) { - byteVectorValues = reader.getByteVectorValues(field); - iterator = (byteVectorValues == null) ? 
null : byteVectorValues.iterator(); - } else { - binary = DocValues.getBinary(reader, field); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values", e); - } - + private class ByteDocValues implements FormattedDocValues { + private byte[] vector; + private ByteVectorValues byteVectorValues; // use when indexed + private KnnVectorValues.DocIndexIterator iterator; // use when indexed + private BinaryDocValues binary; // use when not indexed + + ByteDocValues(int dims) { + this.vector = new byte[dims]; + try { + if (indexed) { + byteVectorValues = reader.getByteVectorValues(field); + iterator = (byteVectorValues == null) ? null : byteVectorValues.iterator(); + } else { + binary = DocValues.getBinary(reader, field); } + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } - @Override - public boolean advanceExact(int docId) throws IOException { - if (indexed) { - if (iteratorAdvanceExact(iterator, docId) == false) { - return false; - } - vector = byteVectorValues.vectorValue(iterator.index()); - } else { - if (binary == null || binary.advanceExact(docId) == false) { - return false; - } - BytesRef ref = binary.binaryValue(); - System.arraycopy(ref.bytes, ref.offset, vector, 0, dims); - } - return true; - } + } - @Override - public int docValueCount() { - return 1; + @Override + public boolean advanceExact(int docId) throws IOException { + if (indexed) { + if (iteratorAdvanceExact(iterator, docId) == false) { + return false; } - - public Object nextValue() { - Byte[] vectorValue = new Byte[dims]; - for (int i = 0; i < dims; i++) { - vectorValue[i] = vector[i]; - } - return vectorValue; + vector = byteVectorValues.vectorValue(iterator.index()); + } else { + if (binary == null || binary.advanceExact(docId) == false) { + return false; } - }; - case FLOAT, BFLOAT16 -> new FormattedDocValues() { - float[] vector = new float[dims]; - private FloatVectorValues floatVectorValues; // use when indexed - 
private KnnVectorValues.DocIndexIterator iterator; // use when indexed - private BinaryDocValues binary; // use when not indexed - { - try { - if (indexed) { - floatVectorValues = reader.getFloatVectorValues(field); - iterator = (floatVectorValues == null) ? null : floatVectorValues.iterator(); - } else { - binary = DocValues.getBinary(reader, field); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values", e); - } + BytesRef ref = binary.binaryValue(); + System.arraycopy(ref.bytes, ref.offset, vector, 0, dims); + } + return true; + } - } + @Override + public int docValueCount() { + return 1; + } - @Override - public boolean advanceExact(int docId) throws IOException { - if (indexed) { - if (iteratorAdvanceExact(iterator, docId) == false) { - return false; - } - vector = floatVectorValues.vectorValue(iterator.index()); - } else { - if (binary == null || binary.advanceExact(docId) == false) { - return false; - } - BytesRef ref = binary.binaryValue(); - VectorEncoderDecoder.decodeDenseVector(indexVersion, ref, vector); - } - return true; - } + public Object nextValue() { + Byte[] vectorValue = new Byte[dims]; + for (int i = 0; i < dims; i++) { + vectorValue[i] = vector[i]; + } + return vectorValue; + } + } - @Override - public int docValueCount() { - return 1; + private class FloatDocValues implements FormattedDocValues { + private float[] vector = new float[dims]; + private FloatVectorValues floatVectorValues; // use when indexed + private KnnVectorValues.DocIndexIterator iterator; // use when indexed + private BinaryDocValues binary; // use when not indexed + + FloatDocValues() { + try { + if (indexed) { + floatVectorValues = reader.getFloatVectorValues(field); + iterator = (floatVectorValues == null) ? 
null : floatVectorValues.iterator(); + } else { + binary = DocValues.getBinary(reader, field); } + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } + } - @Override - public Object nextValue() { - return Arrays.copyOf(vector, vector.length); + @Override + public boolean advanceExact(int docId) throws IOException { + if (indexed) { + if (iteratorAdvanceExact(iterator, docId) == false) { + return false; + } + vector = floatVectorValues.vectorValue(iterator.index()); + } else { + if (binary == null || binary.advanceExact(docId) == false) { + return false; } - }; + BytesRef ref = binary.binaryValue(); + decodeDenseVector(indexVersion, ref, vector); + } + return true; + } + + void decodeDenseVector(IndexVersion indexVersion, BytesRef ref, float[] vector) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, ref, vector); + } + + @Override + public int docValueCount() { + return 1; + } + + @Override + public Object nextValue() { + return Arrays.copyOf(vector, vector.length); + } + } + + private class BFloat16DocValues extends FloatDocValues { + @Override + void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, float[] vector) { + ShortBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + BFloat16.bFloat16ToFloat(fb, vector); + } + } + + @Override + public FormattedDocValues getFormattedValues(DocValueFormat format) { + return switch (elementType) { + case BYTE -> new ByteDocValues(dims); + case BIT -> new ByteDocValues(dims / Byte.SIZE); + case FLOAT -> new FloatDocValues(); + case BFLOAT16 -> new BFloat16DocValues(); }; } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..dc540eb861890 --- /dev/null +++ 
b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; + +public class BFloat16BinaryDenseVectorDocValuesField extends BinaryDenseVectorDocValuesField { + public BFloat16BinaryDenseVectorDocValuesField( + BinaryDocValues input, + String name, + DenseVectorFieldMapper.ElementType elementType, + int dims, + IndexVersion indexVersion + ) { + super(input, name, elementType, dims, indexVersion); + } + + @Override + void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, float[] vector) { + ShortBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer(); + BFloat16.bFloat16ToFloat(fb, vector); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java index 0bb9d2a3a0b0d..376d04cbeb957 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java +++ 
b/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java @@ -86,8 +86,12 @@ public DenseVector getInternal() { private void decodeVectorIfNecessary() { if (decoded == false && value != null) { - VectorEncoderDecoder.decodeDenseVector(indexVersion, value, vectorValue); + decodeDenseVector(indexVersion, value, vectorValue); decoded = true; } } + + void decodeDenseVector(IndexVersion indexVersion, BytesRef value, float[] vector) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, value, vectorValue); + } } From 8395687c8b0a83154b91dece4281449c28dca683 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 30 Oct 2025 16:32:08 +0000 Subject: [PATCH 35/46] Test and fixes --- .../200_dense_vector_docvalue_fields.yml | 102 +++++++++++++++++- .../mapper/vectors/VectorDVLeafFieldData.java | 2 + 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml index 161fc23a84651..1861e021c6b34 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml @@ -123,7 +123,6 @@ setup: - match: {hits.hits.0.fields.vector5.0: [1, 111, -13, 15, -128]} - match: {hits.hits.0.fields.vector6.0: [-1, 11, 0, 12, 111]} - - match: {hits.hits.1._id: "2"} - match: {hits.hits.1.fields.name.0: "moose.jpg"} @@ -143,7 +142,6 @@ setup: - match: {hits.hits.1.fields.vector4.0: [-1, 50, -1, 1, 120]} - match: {hits.hits.1.fields.vector5.0: [1, 111, -13, 15, -128]} - - match: {hits.hits.2._id: "3"} - match: {hits.hits.2.fields.name.0: "rabbit.jpg"} @@ -161,3 +159,103 @@ setup: - close_to: { hits.hits.2.fields.vector2.0.4: { value: -100.0, error: 0.001 } } - 
match: {hits.hits.2.fields.vector3.0: [-1, 100, -13, 15, -128]} + +--- +"dense_vector docvalues with bfloat16": + - requires: + cluster_features: [ "mapper.vectors.hnsw_bfloat16_on_disk_rescoring" ] + reason: Needs bfloat16 support + - do: + indices.create: + index: test-bfloat16 + body: + mappings: + properties: + name: + type: keyword + vector7: + type: dense_vector + element_type: bfloat16 + dims: 5 + index: true + vector8: + type: dense_vector + element_type: bfloat16 + dims: 5 + index: false + + - do: + index: + index: test-bfloat16 + id: "1" + body: + name: cow.jpg + vector7: [230.0, 300.33, -34.8988, 15.555, -200.0] + vector8: [130.0, 115.0, -1.02, 15.555, -100.0] + - do: + index: + index: test-bfloat16 + id: "2" + body: + name: moose.jpg + vector7: [-0.5, 100.0, -13, 14.8, -156.0] + - do: + index: + index: test-bfloat16 + id: "3" + body: + name: rabbit.jpg + vector8: [130.0, 115.0, -1.02, 15.555, -100.0] + + - do: + indices.refresh: {} + + - do: + search: + _source: false + index: test-bfloat16 + body: + docvalue_fields: [name, vector7, vector8] + sort: name + + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.0.fields.name.0: "cow.jpg"} + + - length: {hits.hits.0.fields.vector7.0: 5} + - length: {hits.hits.0.fields.vector8.0: 5} + + - close_to: { hits.hits.0.fields.vector7.0.0: { value: 230.0, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector7.0.1: { value: 300.33, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector7.0.2: { value: -34.8988, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector7.0.3: { value: 15.555, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector7.0.4: { value: -200.0, error: 0.01 } } + + - close_to: { hits.hits.0.fields.vector8.0.0: { value: 130.0, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector8.0.1: { value: 115.0, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector8.0.2: { value: -1.02, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector8.0.3: { value: 15.555, error: 0.01 } } + - 
close_to: { hits.hits.0.fields.vector8.0.4: { value: -100.0, error: 0.01 } } + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1.fields.name.0: "moose.jpg"} + + - length: {hits.hits.1.fields.vector7.0: 5} + - match: {hits.hits.1.fields.vector8: null} + + - close_to: { hits.hits.1.fields.vector7.0.0: { value: -0.5, error: 0.01 } } + - close_to: { hits.hits.1.fields.vector7.0.1: { value: 100.0, error: 0.01 } } + - close_to: { hits.hits.1.fields.vector7.0.2: { value: -13, error: 0.01 } } + - close_to: { hits.hits.1.fields.vector7.0.3: { value: 14.8, error: 0.01 } } + - close_to: { hits.hits.1.fields.vector7.0.4: { value: -156.0, error: 0.01 } } + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2.fields.name.0: "rabbit.jpg"} + + - length: {hits.hits.2.fields.vector8.0: 5} + - match: {hits.hits.2.fields.vector7: null} + + - close_to: { hits.hits.2.fields.vector8.0.0: { value: 130.0, error: 0.01 } } + - close_to: { hits.hits.2.fields.vector8.0.1: { value: 115.0, error: 0.01 } } + - close_to: { hits.hits.2.fields.vector8.0.2: { value: -1.02, error: 0.01 } } + - close_to: { hits.hits.2.fields.vector8.0.3: { value: 15.555, error: 0.01 } } + - close_to: { hits.hits.2.fields.vector8.0.4: { value: -100.0, error: 0.01 } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index f8729ac0661ec..2edd7d89a3778 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -93,12 +93,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { } private class ByteDocValues implements FormattedDocValues { + private final int dims; private byte[] vector; private ByteVectorValues byteVectorValues; // use when indexed private KnnVectorValues.DocIndexIterator iterator; // use when indexed 
private BinaryDocValues binary; // use when not indexed ByteDocValues(int dims) { + this.dims = dims; this.vector = new byte[dims]; try { if (indexed) { From 9eaa1524fdeda85e480635e15a9e8e26a4356aab Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 31 Oct 2025 10:14:31 +0000 Subject: [PATCH 36/46] PR comments --- .../200_dense_vector_docvalue_fields.yml | 40 +++++++++---------- .../index/codec/vectors/BFloat16.java | 2 - .../DenseVectorFromBinaryBlockLoader.java | 9 +---- .../vectors/DenseVectorFieldMapper.java | 19 +-------- .../mapper/vectors/VectorDVLeafFieldData.java | 9 +---- .../mapper/vectors/VectorEncoderDecoder.java | 10 +++++ ...loat16BinaryDenseVectorDocValuesField.java | 9 +---- .../mapper/RankVectorsFieldMapper.java | 5 ++- 8 files changed, 39 insertions(+), 64 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml index 1861e021c6b34..59b9840fdfd28 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml @@ -224,17 +224,17 @@ setup: - length: {hits.hits.0.fields.vector7.0: 5} - length: {hits.hits.0.fields.vector8.0: 5} - - close_to: { hits.hits.0.fields.vector7.0.0: { value: 230.0, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector7.0.1: { value: 300.33, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector7.0.2: { value: -34.8988, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector7.0.3: { value: 15.555, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector7.0.4: { value: -200.0, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector7.0.0: { value: 230.0, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.1: { value: 300.33, error: 0.1 } 
} + - close_to: { hits.hits.0.fields.vector7.0.2: { value: -34.8988, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.3: { value: 15.555, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.4: { value: -200.0, error: 0.1 } } - - close_to: { hits.hits.0.fields.vector8.0.0: { value: 130.0, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector8.0.1: { value: 115.0, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector8.0.2: { value: -1.02, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector8.0.3: { value: 15.555, error: 0.01 } } - - close_to: { hits.hits.0.fields.vector8.0.4: { value: -100.0, error: 0.01 } } + - close_to: { hits.hits.0.fields.vector8.0.0: { value: 130.0, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.1: { value: 115.0, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.2: { value: -1.02, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.3: { value: 15.555, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.4: { value: -100.0, error: 0.1 } } - match: {hits.hits.1._id: "2"} - match: {hits.hits.1.fields.name.0: "moose.jpg"} @@ -242,11 +242,11 @@ setup: - length: {hits.hits.1.fields.vector7.0: 5} - match: {hits.hits.1.fields.vector8: null} - - close_to: { hits.hits.1.fields.vector7.0.0: { value: -0.5, error: 0.01 } } - - close_to: { hits.hits.1.fields.vector7.0.1: { value: 100.0, error: 0.01 } } - - close_to: { hits.hits.1.fields.vector7.0.2: { value: -13, error: 0.01 } } - - close_to: { hits.hits.1.fields.vector7.0.3: { value: 14.8, error: 0.01 } } - - close_to: { hits.hits.1.fields.vector7.0.4: { value: -156.0, error: 0.01 } } + - close_to: { hits.hits.1.fields.vector7.0.0: { value: -0.5, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.1: { value: 100.0, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.2: { value: -13, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.3: { value: 14.8, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.4: 
{ value: -156.0, error: 0.1 } } - match: {hits.hits.2._id: "3"} - match: {hits.hits.2.fields.name.0: "rabbit.jpg"} @@ -254,8 +254,8 @@ setup: - length: {hits.hits.2.fields.vector8.0: 5} - match: {hits.hits.2.fields.vector7: null} - - close_to: { hits.hits.2.fields.vector8.0.0: { value: 130.0, error: 0.01 } } - - close_to: { hits.hits.2.fields.vector8.0.1: { value: 115.0, error: 0.01 } } - - close_to: { hits.hits.2.fields.vector8.0.2: { value: -1.02, error: 0.01 } } - - close_to: { hits.hits.2.fields.vector8.0.3: { value: 15.555, error: 0.01 } } - - close_to: { hits.hits.2.fields.vector8.0.4: { value: -100.0, error: 0.01 } } + - close_to: { hits.hits.2.fields.vector8.0.0: { value: 130.0, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.1: { value: 115.0, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.2: { value: -1.02, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.3: { value: 15.555, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.4: { value: -100.0, error: 0.1 } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index 11eaf69344c97..8b50a39fe01af 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -38,7 +38,6 @@ public static float bFloat16ToFloat(short bf) { } public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { - assert bFloats.remaining() == floats.length; assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; for (float v : floats) { bFloats.put(floatToBFloat16(v)); @@ -53,7 +52,6 @@ public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { } public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) { - assert floats.length == bFloats.remaining(); assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; for (int i = 0; i < floats.length; i++) { 
floats[i] = bFloat16ToFloat(bFloats.get()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java index 4bb94104f4f63..7e729eca933e5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java @@ -13,15 +13,11 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.ShortBuffer; public class DenseVectorFromBinaryBlockLoader extends BlockDocValuesReader.DocValuesBlockLoader { private final String fieldName; @@ -151,10 +147,7 @@ protected void writeScratchToBuilder(float[] scratch, BlockLoader.FloatBuilder b @Override protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { - ShortBuffer sb = ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length) - .order(ByteOrder.LITTLE_ENDIAN) - .asShortBuffer(); - BFloat16.bFloat16ToFloat(sb, scratch); + VectorEncoderDecoder.decodeBFloat16DenseVector(bytesRef, scratch); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 46b829eea7097..d91f21dbede50 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -532,8 +532,6 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib public abstract ElementType elementType(); - public abstract void writeValue(ByteBuffer byteBuffer, float value); - public abstract void writeValues(ByteBuffer byteBuffer, float[] values); public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException; @@ -636,11 +634,6 @@ public ElementType elementType() { return ElementType.BYTE; } - @Override - public void writeValue(ByteBuffer byteBuffer, float value) { - byteBuffer.put((byte) value); - } - @Override public void writeValues(ByteBuffer byteBuffer, float[] values) { for (float f : values) { @@ -895,11 +888,6 @@ public ElementType elementType() { return ElementType.FLOAT; } - @Override - public void writeValue(ByteBuffer byteBuffer, float value) { - byteBuffer.putFloat(value); - } - @Override public void writeValues(ByteBuffer byteBuffer, float[] values) { byteBuffer.asFloatBuffer().put(values); @@ -1086,14 +1074,9 @@ public ElementType elementType() { return ElementType.BFLOAT16; } - @Override - public void writeValue(ByteBuffer byteBuffer, float value) { - byteBuffer.putShort(BFloat16.floatToBFloat16(value)); - } - @Override public void writeValues(ByteBuffer byteBuffer, float[] values) { - BFloat16.floatToBFloat16(values, byteBuffer.asShortBuffer().limit(values.length)); + BFloat16.floatToBFloat16(values, byteBuffer.asShortBuffer()); byteBuffer.position(byteBuffer.position() + (values.length * BFloat16.BYTES)); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index 2edd7d89a3778..8363f32f8f27f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -17,7 +17,6 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.fielddata.FormattedDocValues; import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; @@ -33,9 +32,6 @@ import org.elasticsearch.search.DocValueFormat; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.ShortBuffer; import java.util.Arrays; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -200,10 +196,7 @@ public Object nextValue() { private class BFloat16DocValues extends FloatDocValues { @Override void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, float[] vector) { - ShortBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length) - .order(ByteOrder.LITTLE_ENDIAN) - .asShortBuffer(); - BFloat16.bFloat16ToFloat(fb, vector); + VectorEncoderDecoder.decodeBFloat16DenseVector(vectorBR, vector); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 9dec4a4f2dd61..f60e7139cd3d9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -12,10 +12,12 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.ShortBuffer; import static 
org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION; @@ -84,6 +86,14 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB } } + public static void decodeBFloat16DenseVector(BytesRef vectorBR, float[] vector) { + if (vectorBR == null) { + throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE); + } + ShortBuffer sb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer(); + BFloat16.bFloat16ToFloat(sb, vector); + } + /** * Decodes a BytesRef into the provided array of bytes * @param vectorBR - dense vector encoded in BytesRef diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java index dc540eb861890..4805b77250c00 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java @@ -12,12 +12,8 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.ShortBuffer; +import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; public class BFloat16BinaryDenseVectorDocValuesField extends BinaryDenseVectorDocValuesField { public BFloat16BinaryDenseVectorDocValuesField( @@ -32,7 +28,6 @@ public BFloat16BinaryDenseVectorDocValuesField( @Override void decodeDenseVector(IndexVersion 
indexVersion, BytesRef vectorBR, float[] vector) { - ShortBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer(); - BFloat16.bFloat16ToFloat(fb, vector); + VectorEncoderDecoder.decodeBFloat16DenseVector(vectorBR, vector); } } diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java index 5c0ad8680a389..7f2bd456db4da 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java @@ -72,11 +72,14 @@ public static class Builder extends FieldMapper.Builder { () -> DenseVectorFieldMapper.ElementType.FLOAT, (n, c, o) -> { DenseVectorFieldMapper.ElementType elementType = namesToElementType.get((String) o); - if (elementType == null || elementType == ElementType.BFLOAT16) { + if (elementType == null) { throw new MapperParsingException( "invalid element_type [" + o + "]; available types are " + namesToElementType.keySet() ); } + if (elementType == ElementType.BFLOAT16) { + throw new MapperParsingException("Rank vectors does not support bfloat16"); + } return elementType; }, m -> toType(m).fieldType().element.elementType(), From 6bef1f0c22c6a2b4325e598c4524f2466a4d0859 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Nov 2025 12:59:33 +0000 Subject: [PATCH 37/46] Add assert for an interface it shouldn't implement yet --- .../codec/vectors/es93/OffHeapBFloat16VectorValues.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java index 
42f02d2d21366..ecae96b08c576 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -21,6 +21,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; @@ -63,6 +64,10 @@ abstract class OffHeapBFloat16VectorValues extends FloatVectorValues { this.flatVectorsScorer = flatVectorsScorer; bfloatBytes = new byte[dimension * BFloat16.BYTES]; value = new float[dimension]; + + assert (this instanceof HasIndexSlice) == false + : "BFloat16 should not implement HasIndexSlice until a bfloat16 scorer is created," + + " else Lucene99MemorySegmentFlatVectorsScorer will try to access 4-byte floats here"; } @Override From 1e905ded052f48f8dd4ffbd143f18ebf4c0396e1 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Nov 2025 14:28:06 +0000 Subject: [PATCH 38/46] Add case --- .../xpack/esql/DenseVectorFieldTypeIT.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index d852ddeaf1d8b..ff720fed2210c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import 
org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.script.field.vectors.DenseVector; @@ -250,6 +251,15 @@ public void setup() throws IOException { buffer.asFloatBuffer().put(array); yield Base64.getEncoder().encodeToString(buffer.array()); } + case BFLOAT16 -> { + float[] array = new float[numDims]; + for (int k = 0; k < numDims; k++) { + array[k] = vector.get(k).floatValue(); + } + final ByteBuffer buffer = ByteBuffer.allocate(BFloat16.BYTES * numDims); + BFloat16.floatToBFloat16(array, buffer.asShortBuffer()); + yield Base64.getEncoder().encodeToString(buffer.array()); + } case BYTE, BIT -> { byte[] array = new byte[numDims]; for (int k = 0; k < numDims; k++) { From e2c2730dd254f39d17e7cc211d9e9e8c8cff5d9d Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Nov 2025 15:43:06 +0000 Subject: [PATCH 39/46] Base64 support --- .../elasticsearch/index/codec/vectors/BFloat16.java | 3 --- .../docvalues/DenseVectorBlockLoader.java | 2 +- .../index/mapper/vectors/DenseVectorFieldMapper.java | 12 +++++++++--- .../elasticsearch/script/VectorScoreScriptUtils.java | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index 8b50a39fe01af..8e558fe5f9191 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -11,7 +11,6 @@ import org.apache.lucene.util.BitUtil; -import java.nio.ByteOrder; import java.nio.ShortBuffer; public final class BFloat16 { @@ -38,7 +37,6 @@ public static float bFloat16ToFloat(short bf) { } public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { - assert bFloats.order() == 
ByteOrder.LITTLE_ENDIAN; for (float v : floats) { bFloats.put(floatToBFloat16(v)); } @@ -52,7 +50,6 @@ public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { } public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) { - assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; for (int i = 0; i < floats.length; i++) { floats[i] = bFloat16ToFloat(bFloats.get()); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java index e5561428364de..32869f8743878 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java @@ -53,7 +53,7 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { switch (fieldType.getElementType()) { - case FLOAT -> { + case FLOAT, BFLOAT16 -> { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { if (fieldType.isNormalized()) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index e89de373ebb8c..738107845842e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1150,20 +1150,26 @@ VectorDataAndMagnitude parseBase64EncodedVector(DocumentParserContext context, I throws IOException { // BIG_ENDIAN is the default, but just being explicit here ByteBuffer byteBuffer = ByteBuffer.wrap(Base64.getDecoder().decode(context.parser().text())).order(ByteOrder.BIG_ENDIAN); - 
if (byteBuffer.remaining() != dims * Float.BYTES) { + float[] decodedVector = new float[dims]; + if (byteBuffer.remaining() == dims * Float.BYTES) { + byteBuffer.asFloatBuffer().get(decodedVector); + } else if (byteBuffer.remaining() == dims * BFloat16.BYTES) { + BFloat16.bFloat16ToFloat(byteBuffer.asShortBuffer(), decodedVector); + } else { throw new ParsingException( context.parser().getTokenLocation(), "Failed to parse object: Base64 decoded vector byte length [" + byteBuffer.remaining() + "] does not match the expected length of [" + (dims * Float.BYTES) + + "] or [" + + (dims * BFloat16.BYTES) + "] for dimension count [" + dims + "]" ); } - float[] decodedVector = new float[dims]; - byteBuffer.asFloatBuffer().get(decodedVector); + dimChecker.accept(decodedVector.length, true); VectorData vectorData = VectorData.fromFloats(decodedVector); float squaredMagnitude = (float) computeSquaredMagnitude(vectorData); diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index b29a0266c220c..6932b5c718dc6 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -252,7 +252,7 @@ public static final class Hamming { @SuppressWarnings("unchecked") public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); - if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) { + if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT || field.getElementType() == ElementType.BFLOAT16) { throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors"); } if (queryVector instanceof List) { From ee46e003d41a5e6d839645409a4390abbcc7f115 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 5 Nov 2025 
16:47:14 +0000 Subject: [PATCH 40/46] Add a feature flag for the new formats --- .../es93/ES93GenericFlatVectorsFormat.java | 3 +++ .../index/mapper/MapperFeatures.java | 12 +++++++--- .../vectors/DenseVectorFieldMapper.java | 23 +++++++++++++++---- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index 0eedc3f16cdfb..c70d40f5cb03a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -15,6 +15,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -24,6 +25,8 @@ public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { + public static final FeatureFlag ES93_VECTOR_FORMATS = new FeatureFlag("es93_vector_formats"); + static final String NAME = "ES93GenericFlatVectorsFormat"; static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta"; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 9294636055567..bb24068f499d1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -11,7 +11,9 @@ import 
org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; +import java.util.HashSet; import java.util.Set; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING; @@ -63,7 +65,7 @@ public class MapperFeatures implements FeatureSpecification { @Override public Set getTestFeatures() { - return Set.of( + var features = Set.of( RangeFieldMapper.DATE_RANGE_INDEXING_FIX, IgnoredSourceFieldMapper.DONT_EXPAND_DOTS_IN_IGNORED_SOURCE, SourceFieldMapper.REMOVE_SYNTHETIC_SOURCE_ONLY_VALIDATION, @@ -102,8 +104,12 @@ public Set getTestFeatures() { PROVIDE_INDEX_SORT_SETTING_DEFAULTS, INDEX_MAPPING_IGNORE_DYNAMIC_BEYOND_FIELD_NAME_LIMIT, EXCLUDE_VECTORS_DOCVALUE_BUGFIX, - BASE64_DENSE_VECTORS, - HNSW_BFLOAT16_ON_DISK_RESCORING + BASE64_DENSE_VECTORS ); + if (ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled()) { + features = new HashSet<>(features); + features.add(HNSW_BFLOAT16_ON_DISK_RESCORING); + } + return features; } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 738107845842e..5512aec27ed92 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -54,7 +54,10 @@ import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; import 
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; @@ -238,7 +241,8 @@ public static class Builder extends FieldMapper.Builder { private final Parameter elementType = new Parameter<>("element_type", false, () -> ElementType.FLOAT, (n, c, o) -> { ElementType elementType = namesToElementType.get((String) o); - if (elementType == null) { + if (elementType == null + || (elementType == ElementType.BFLOAT16 && ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled() == false)) { throw new MapperParsingException("invalid element_type [" + o + "]; available types are " + namesToElementType.keySet()); } return elementType; @@ -2145,7 +2149,14 @@ public static class HnswIndexOptions extends DenseVectorIndexOptions { @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { - return new ES93HnswVectorsFormat(m, efConstruction, elementType); + if (ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled()) { + return new ES93HnswVectorsFormat(m, efConstruction, elementType); + } else { + if (elementType == ElementType.BIT) { + return new ES815HnswBitVectorsFormat(m, efConstruction); + } + return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null); + } } @Override @@ -2220,7 +2231,9 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, elementType, onDiskRescore); + return ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled() + ? 
new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, elementType, onDiskRescore) + : new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); } @Override @@ -2286,7 +2299,9 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return new ES93BinaryQuantizedVectorsFormat(elementType, false); + return ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled() + ? new ES93BinaryQuantizedVectorsFormat(elementType, false) + : new ES818BinaryQuantizedVectorsFormat(); } @Override From dce8dd88a1a9c594e8fd24567e6ad06fe336d6c6 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Nov 2025 10:20:36 +0000 Subject: [PATCH 41/46] Update the feature flag & cluster feature names --- docs/changelog/135940.yaml | 5 ----- .../search.vectors/200_dense_vector_docvalue_fields.yml | 2 +- .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 2 +- .../search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml | 2 +- .../search.vectors/42_knn_search_bbq_flat_bfloat16.yml | 2 +- .../codec/vectors/es93/ES93GenericFlatVectorsFormat.java | 2 +- .../org/elasticsearch/index/mapper/MapperFeatures.java | 6 +++--- .../index/mapper/vectors/DenseVectorFieldMapper.java | 8 ++++---- 8 files changed, 12 insertions(+), 17 deletions(-) delete mode 100644 docs/changelog/135940.yaml diff --git a/docs/changelog/135940.yaml b/docs/changelog/135940.yaml deleted file mode 100644 index d754c98f320f5..0000000000000 --- a/docs/changelog/135940.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 135940 -summary: Enable directIO and bfloat16 for bbq and unquantized vector field types -area: Vector Search -type: feature -issues: [] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml index 59b9840fdfd28..f9150cd49ac1a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml @@ -163,7 +163,7 @@ setup: --- "dense_vector docvalues with bfloat16": - requires: - cluster_features: [ "mapper.vectors.hnsw_bfloat16_on_disk_rescoring" ] + cluster_features: [ "mapper.vectors.generic_vector_format" ] reason: Needs bfloat16 support - do: indices.create: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 0a2c9eb31dc72..1111d1432f327 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -340,7 +340,7 @@ setup: --- "Test index configured rescore vector with on-disk rescoring": - requires: - cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] + cluster_features: ["mapper.vectors.generic_vector_format"] reason: Needs on_disk_rescoring feature - skip: features: "headers" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index d873158e637c3..ff9273cacc0f4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -1,6 +1,6 @@ setup: - requires: - 
cluster_features: "mapper.vectors.hnsw_bfloat16_on_disk_rescoring" + cluster_features: "mapper.vectors.generic_vector_format" reason: 'bfloat16 needs to be supported' - do: indices.create: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index 80fee2c53468f..8df025e621845 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -1,6 +1,6 @@ setup: - requires: - cluster_features: "mapper.vectors.hnsw_bfloat16_on_disk_rescoring" + cluster_features: "mapper.vectors.generic_vector_format" reason: 'bfloat16 needs to be supported' - do: indices.create: diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index c70d40f5cb03a..d2278bc0f30bf 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -25,7 +25,7 @@ public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { - public static final FeatureFlag ES93_VECTOR_FORMATS = new FeatureFlag("es93_vector_formats"); + public static final FeatureFlag GENERIC_VECTOR_FORMAT = new FeatureFlag("generic_vector_format"); static final String NAME = "ES93GenericFlatVectorsFormat"; static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index bb24068f499d1..b0cedb3e779e9 100644 
--- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -61,7 +61,7 @@ public class MapperFeatures implements FeatureSpecification { ); static final NodeFeature EXCLUDE_VECTORS_DOCVALUE_BUGFIX = new NodeFeature("mapper.exclude_vectors_docvalue_bugfix"); static final NodeFeature BASE64_DENSE_VECTORS = new NodeFeature("mapper.base64_dense_vectors"); - public static final NodeFeature HNSW_BFLOAT16_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.hnsw_bfloat16_on_disk_rescoring"); + public static final NodeFeature GENERIC_VECTOR_FORMAT = new NodeFeature("mapper.vectors.generic_vector_format"); @Override public Set getTestFeatures() { @@ -106,9 +106,9 @@ public Set getTestFeatures() { EXCLUDE_VECTORS_DOCVALUE_BUGFIX, BASE64_DENSE_VECTORS ); - if (ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled()) { + if (ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled()) { features = new HashSet<>(features); - features.add(HNSW_BFLOAT16_ON_DISK_RESCORING); + features.add(GENERIC_VECTOR_FORMAT); } return features; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 5512aec27ed92..ada968f9756b5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -242,7 +242,7 @@ public static class Builder extends FieldMapper.Builder { private final Parameter elementType = new Parameter<>("element_type", false, () -> ElementType.FLOAT, (n, c, o) -> { ElementType elementType = namesToElementType.get((String) o); if (elementType == null - || (elementType == ElementType.BFLOAT16 && ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled() == false)) { + || (elementType == 
ElementType.BFLOAT16 && ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() == false)) { throw new MapperParsingException("invalid element_type [" + o + "]; available types are " + namesToElementType.keySet()); } return elementType; @@ -2149,7 +2149,7 @@ public static class HnswIndexOptions extends DenseVectorIndexOptions { @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { - if (ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled()) { + if (ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled()) { return new ES93HnswVectorsFormat(m, efConstruction, elementType); } else { if (elementType == ElementType.BIT) { @@ -2231,7 +2231,7 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled() + return ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() ? new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, elementType, onDiskRescore) : new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); } @@ -2299,7 +2299,7 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16; - return ES93GenericFlatVectorsFormat.ES93_VECTOR_FORMATS.isEnabled() + return ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() ? 
new ES93BinaryQuantizedVectorsFormat(elementType, false) : new ES818BinaryQuantizedVectorsFormat(); } From 66aa7674c7b36e2e27dee66de2aae58634bde0b9 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Nov 2025 12:30:27 +0000 Subject: [PATCH 42/46] Remove obsoleted yaml tests --- .../search.vectors/40_knn_search_bfloat16.yml | 78 ------------------- 1 file changed, 78 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml index aa80764f14136..30dc27cfa8705 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml @@ -214,84 +214,6 @@ setup: - match: { error.root_cause.0.type: "illegal_argument_exception" } - match: { error.root_cause.0.reason: "cannot set [search_type] when using [knn] search, since the search type is determined automatically" } ---- -"kNN search in _knn_search endpoint": - - skip: - features: [ "allowed_warnings", "headers" ] - - do: - headers: - Content-Type: "application/vnd.elasticsearch+json;compatible-with=8" - Accept: "application/vnd.elasticsearch+json;compatible-with=8" - allowed_warnings: - - "The kNN search API has been replaced by the `knn` option in the search API." 
- knn_search: - index: test - body: - fields: [ "name" ] - knn: - field: vector - query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] - k: 2 - num_candidates: 3 - - - match: { hits.hits.0._id: "2" } - - match: { hits.hits.0.fields.name.0: "moose.jpg" } - - - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1.fields.name.0: "rabbit.jpg" } - ---- -"kNN search with filter in _knn_search endpoint": - - requires: - cluster_features: "gte_v8.2.0" - reason: 'kNN with filtering added in 8.2' - test_runner_features: [ "allowed_warnings", "headers" ] - - do: - headers: - Content-Type: "application/vnd.elasticsearch+json;compatible-with=8" - Accept: "application/vnd.elasticsearch+json;compatible-with=8" - allowed_warnings: - - "The kNN search API has been replaced by the `knn` option in the search API." - knn_search: - index: test - body: - fields: [ "name" ] - knn: - field: vector - query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] - k: 2 - num_candidates: 3 - filter: - term: - name: "rabbit.jpg" - - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } - - - do: - headers: - Content-Type: "application/vnd.elasticsearch+json;compatible-with=8" - Accept: "application/vnd.elasticsearch+json;compatible-with=8" - allowed_warnings: - - "The kNN search API has been replaced by the `knn` option in the search API." 
- knn_search: - index: test - body: - fields: [ "name" ] - knn: - field: vector - query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] - k: 2 - num_candidates: 3 - filter: - - term: - name: "rabbit.jpg" - - term: - _id: 2 - - - match: { hits.total.value: 0 } - --- "Test nonexistent field is match none": - requires: From eed749e3a6d09e03e0731f5c526fa49e15cd471a Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Thu, 6 Nov 2025 14:02:49 +0000 Subject: [PATCH 43/46] Turn off yaml tests for now --- .../test/search.vectors/200_dense_vector_docvalue_fields.yml | 3 +++ .../test/search.vectors/40_knn_search_bfloat16.yml | 3 +++ .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 3 +++ .../test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml | 3 +++ .../test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml | 3 +++ .../test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml | 3 +++ 6 files changed, 18 insertions(+) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml index f9150cd49ac1a..a6034ef20c0bd 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml @@ -165,6 +165,9 @@ setup: - requires: cluster_features: [ "mapper.vectors.generic_vector_format" ] reason: Needs bfloat16 support + - skip: + awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 + reason: Feature flag not enabled - do: indices.create: index: test-bfloat16 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml index 30dc27cfa8705..f1ad80522dcff 100644 --- 
a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml @@ -2,6 +2,9 @@ setup: - requires: cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] reason: 'bfloat16 needs to be supported' + - skip: + awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 + reason: Feature flag not enabled - do: indices.create: index: test diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 1111d1432f327..8a9bed81ab4aa 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -342,6 +342,9 @@ setup: - requires: cluster_features: ["mapper.vectors.generic_vector_format"] reason: Needs on_disk_rescoring feature + - skip: + awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 + reason: Feature flag not enabled - skip: features: "headers" - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index ff9273cacc0f4..7652052c989df 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -2,6 +2,9 @@ setup: - requires: cluster_features: "mapper.vectors.generic_vector_format" reason: 'bfloat16 needs to be supported' + - skip: + awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 + 
reason: Feature flag not enabled - do: indices.create: index: bbq_hnsw diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index 8df025e621845..0bf57590a1ee0 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -2,6 +2,9 @@ setup: - requires: cluster_features: "mapper.vectors.generic_vector_format" reason: 'bfloat16 needs to be supported' + - skip: + awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 + reason: Feature flag not enabled - do: indices.create: index: bbq_flat diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml index aa59bceb00598..f4f09013f9751 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml @@ -2,6 +2,9 @@ setup: - requires: cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] reason: 'bfloat16 needs to be supported' + - skip: + awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 + reason: Feature flag not enabled - skip: features: "headers" - do: From ab5ee251aa0c298a2ee2d980c54eb93af40a064a Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Nov 2025 14:41:49 +0000 Subject: [PATCH 44/46] Remove index version (not needed yet) --- server/src/main/java/org/elasticsearch/index/IndexVersions.java | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index c16d307658ac4..1352c40890fc4 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -189,11 +189,9 @@ private static Version parseUnchecked(String version) { public static final IndexVersion BACKPORT_UPGRADE_TO_LUCENE_10_3_1 = def(9_039_0_01, Version.LUCENE_10_3_1); public static final IndexVersion KEYWORD_MULTI_FIELDS_NOT_STORED_WHEN_IGNORED = def(9_040_0_00, Version.LUCENE_10_3_0); public static final IndexVersion UPGRADE_TO_LUCENE_10_3_1 = def(9_041_0_00, Version.LUCENE_10_3_1); - public static final IndexVersion REENABLED_TIMESTAMP_DOC_VALUES_SPARSE_INDEX = def(9_042_0_00, Version.LUCENE_10_3_1); public static final IndexVersion SKIPPERS_ENABLED_BY_DEFAULT = def(9_043_0_00, Version.LUCENE_10_3_1); public static final IndexVersion TIME_SERIES_USE_SYNTHETIC_ID = def(9_044_0_00, Version.LUCENE_10_3_1); - public static final IndexVersion BFLOAT16_HNSW_SUPPORT = def(9_045_0_00, Version.LUCENE_10_3_1); /* * STOP! READ THIS FIRST! 
No, really, From 2a58d2f335e252bc68ed1928bd3c87d4eb4194d9 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Fri, 7 Nov 2025 16:51:29 +0000 Subject: [PATCH 45/46] Don't accept the rescore option if the flag is not enabled --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index ada968f9756b5..2e83828944a67 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1640,7 +1640,9 @@ public boolean supportsDimension(int dims) { public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object onDiskRescoreNode = indexOptionsMap.remove("on_disk_rescore"); + Object onDiskRescoreNode = ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() + ? 
indexOptionsMap.remove("on_disk_rescore") + : false; int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); @@ -1731,7 +1733,9 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map Date: Fri, 7 Nov 2025 17:10:49 +0000 Subject: [PATCH 46/46] Add generic vector flag to the yaml runners --- .../test/rest/yaml/CcsCommonYamlTestSuiteIT.java | 3 ++- .../test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java | 1 + .../smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java | 1 + .../org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java | 1 + .../search.vectors/200_dense_vector_docvalue_fields.yml | 5 +---- .../test/search.vectors/40_knn_search_bfloat16.yml | 7 ++----- .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 7 ++----- .../search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml | 7 ++----- .../search.vectors/42_knn_search_bbq_flat_bfloat16.yml | 7 ++----- .../test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml | 7 ++----- .../java/org/elasticsearch/test/cluster/FeatureFlag.java | 3 ++- 11 files changed, 18 insertions(+), 31 deletions(-) diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java index 7b7fa6b4ab6d1..f517e52ce423a 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java @@ -96,7 +96,8 @@ public class CcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase { // geohex_grid requires gold license .setting("xpack.license.self_generated.type", "trial") .feature(FeatureFlag.TIME_SERIES_MODE) - .feature(FeatureFlag.SYNTHETIC_VECTORS); + 
.feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT); private static ElasticsearchCluster remoteCluster = ElasticsearchCluster.local() .name(REMOTE_CLUSTER_NAME) diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java index e5362c31f32f9..7d1ed9d92238a 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java @@ -96,6 +96,7 @@ public class RcsCcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .setting("xpack.security.remote_cluster_client.ssl.enabled", "false") .feature(FeatureFlag.TIME_SERIES_MODE) .feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT) .user("test_admin", "x-pack-test-password"); private static ElasticsearchCluster fulfillingCluster = ElasticsearchCluster.local() diff --git a/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java b/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java index 529d7a7155264..8af1760f2ebd3 100644 --- a/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java +++ b/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java @@ -38,6 +38,7 @@ public class SmokeTestMultiNodeClientYamlTestSuiteIT extends ESClientYamlSuiteTe .feature(FeatureFlag.DOC_VALUES_SKIPPER) .feature(FeatureFlag.SYNTHETIC_VECTORS) .feature(FeatureFlag.RANDOM_SAMPLING) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT) .build(); public SmokeTestMultiNodeClientYamlTestSuiteIT(@Name("yaml") 
ClientYamlTestCandidate testCandidate) { diff --git a/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java b/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java index 2cebc6a743703..45b049784c380 100644 --- a/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java +++ b/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java @@ -38,6 +38,7 @@ public class ClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .feature(FeatureFlag.DOC_VALUES_SKIPPER) .feature(FeatureFlag.SYNTHETIC_VECTORS) .feature(FeatureFlag.RANDOM_SAMPLING) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT) .build(); public ClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml index a6034ef20c0bd..6edbcf5ef28ff 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml @@ -164,10 +164,7 @@ setup: "dense_vector docvalues with bfloat16": - requires: cluster_features: [ "mapper.vectors.generic_vector_format" ] - reason: Needs bfloat16 support - - skip: - awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 - reason: Feature flag not enabled + reason: Needs generic vector support - do: indices.create: index: test-bfloat16 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml index f1ad80522dcff..51adafc624469 100644 --- 
a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml @@ -1,10 +1,7 @@ setup: - requires: - cluster_features: ["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] - reason: 'bfloat16 needs to be supported' - - skip: - awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 - reason: Feature flag not enabled + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support - do: indices.create: index: test diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 8a9bed81ab4aa..ed15ed4d09806 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -340,11 +340,8 @@ setup: --- "Test index configured rescore vector with on-disk rescoring": - requires: - cluster_features: ["mapper.vectors.generic_vector_format"] - reason: Needs on_disk_rescoring feature - - skip: - awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 - reason: Feature flag not enabled + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support - skip: features: "headers" - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml index 7652052c989df..358089c5342ad 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml +++ 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -1,10 +1,7 @@ setup: - requires: - cluster_features: "mapper.vectors.generic_vector_format" - reason: 'bfloat16 needs to be supported' - - skip: - awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 - reason: Feature flag not enabled + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support - do: indices.create: index: bbq_hnsw diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml index 0bf57590a1ee0..6cbeb9ecfd189 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -1,10 +1,7 @@ setup: - requires: - cluster_features: "mapper.vectors.generic_vector_format" - reason: 'bfloat16 needs to be supported' - - skip: - awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 - reason: Feature flag not enabled + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support - do: indices.create: index: bbq_flat diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml index f4f09013f9751..b47d337120c54 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml @@ -1,10 +1,7 @@ setup: - requires: - cluster_features: 
["mapper.vectors.hnsw_bfloat16_on_disk_rescoring"] - reason: 'bfloat16 needs to be supported' - - skip: - awaits_fix: https://github.com/elastic/elasticsearch/issues/131109 - reason: Feature flag not enabled + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support - skip: features: "headers" - do: diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index a1a0486aecfc8..527d961197e8a 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -27,7 +27,8 @@ public enum FeatureFlag { null ), RANDOM_SAMPLING("es.random_sampling_feature_flag_enabled=true", Version.fromString("9.2.0"), null), - INFERENCE_API_CCM("es.inference_api_ccm_feature_flag_enabled=true", Version.fromString("9.3.0"), null); + INFERENCE_API_CCM("es.inference_api_ccm_feature_flag_enabled=true", Version.fromString("9.3.0"), null), + GENERIC_VECTOR_FORMAT("es.generic_vector_format_feature_flag_enabled=true", Version.fromString("9.3.0"), null); public final String systemProperty; public final Version from;