diff --git a/docs/changelog/129046.yaml b/docs/changelog/129046.yaml
new file mode 100644
index 0000000000000..008a1f61020a2
--- /dev/null
+++ b/docs/changelog/129046.yaml
@@ -0,0 +1,5 @@
+pr: 129046
+summary: Add Lucene improvements for HNSW merging heap usage
+area: Search
+type: enhancement
+issues: []
diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java
index cd418d21e05c8..3182c501613e7 100644
--- a/server/src/main/java/module-info.java
+++ b/server/src/main/java/module-info.java
@@ -456,6 +456,7 @@
         org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
         org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
         org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
+        org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat,
         org.elasticsearch.index.codec.vectors.IVFVectorsFormat;

     provides org.apache.lucene.codecs.Codec
diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java
index 04428d5b37fba..3d81536f12323 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java
@@ -16,10 +16,10 @@
 import org.apache.lucene.codecs.PostingsFormat;
 import org.apache.lucene.codecs.StoredFieldsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
-import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
+import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat;
 import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat;

 /**
@@ -68,7 +68,7 @@ public Elasticsearch900Codec(Zstd814StoredFieldsFormat.Mode mode) {
         this.storedFieldsFormat = mode.getFormat();
         this.defaultPostingsFormat = new Lucene912PostingsFormat();
         this.defaultDVFormat = new Lucene90DocValuesFormat();
-        this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat();
+        this.defaultKnnVectorsFormat = new ES910HnswVectorsFormat();
     }

     @Override
diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java
index 3edd55d8f8de7..bb061e459953a 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java
@@ -16,10 +16,10 @@
 import org.apache.lucene.codecs.lucene101.Lucene101Codec;
 import org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
-import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
 import org.elasticsearch.index.codec.perfield.XPerFieldDocValuesFormat;
+import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat;
 import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat;

 /**
@@ -70,7 +70,7 @@ public Elasticsearch900Lucene101Codec(Zstd814StoredFieldsFormat.Mode mode) {
         this.storedFieldsFormat = mode.getFormat();
         this.defaultPostingsFormat =
DEFAULT_POSTINGS_FORMAT; this.defaultDVFormat = new Lucene90DocValuesFormat(); - this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat(); + this.defaultKnnVectorsFormat = new ES910HnswVectorsFormat(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java index ecb0d6d5eb3ca..3ce00d38a5b2f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java +++ b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java @@ -13,7 +13,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.IndexSettings; @@ -21,6 +20,7 @@ import org.elasticsearch.index.codec.bloomfilter.ES87BloomFilterPostingsFormat; import org.elasticsearch.index.codec.postings.ES812PostingsFormat; import org.elasticsearch.index.codec.tsdb.es819.ES819TSDBDocValuesFormat; +import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat; import org.elasticsearch.index.mapper.CompletionFieldMapper; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.mapper.Mapper; @@ -34,7 +34,7 @@ public class PerFieldFormatSupplier { private static final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat(); - private static final KnnVectorsFormat knnVectorsFormat = new Lucene99HnswVectorsFormat(); + private static final KnnVectorsFormat knnVectorsFormat = new ES910HnswVectorsFormat(); private static final ES819TSDBDocValuesFormat tsdbDocValuesFormat = new ES819TSDBDocValuesFormat(); private static final ES812PostingsFormat es812PostingsFormat = new ES812PostingsFormat(); private static final PostingsFormat completionPostingsFormat = PostingsFormat.forName("Completion101"); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java index 6bb32d8e1ef52..3a733f2dc242e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java @@ -14,9 +14,9 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsWriter; import java.io.IOException; @@ -61,7 +61,7 @@ public ES814HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth, Float c @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + return new ES910HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state)); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java index 186dfcbeb5d52..6febadd35e33d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java @@ -14,9 +14,9 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsWriter; import java.io.IOException; @@ -56,7 +56,7 @@ public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + return new ES910HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state)); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java index 56942017c3cef..67235dfa56ef7 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java @@ -25,11 +25,11 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.hnsw.HnswGraph; +import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsWriter; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -119,7 +119,7 @@ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + return new ES910HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state)); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsFormat.java new file mode 100644 index 0000000000000..1bf0bcf0e3775 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsFormat.java @@ -0,0 +1,125 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10.3.0 + */ +public class ES910HnswVectorsFormat extends KnnVectorsFormat { + + static final String NAME = "ES910HnswVectorsFormat"; + + static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta"; + static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex"; + static final String META_EXTENSION = "vem"; + static final String VECTOR_INDEX_EXTENSION = "vex"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, and merging vectors on disk. */ + private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Constructs a format using default graph construction parameters */ + public ES910HnswVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. 
+ */ + public ES910HnswVectorsFormat(int maxConn, int beamWidth) { + super(ES910HnswVectorsFormat.NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES910HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state)); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return Lucene99HnswVectorsFormat.DEFAULT_MAX_DIMENSIONS; + } + + @Override + public String toString() { + return "ES910HnswReducedHeapVectorsFormat(name=ES910HnswReducedHeapVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsWriter.java new file mode 100644 index 0000000000000..4b816562be0d1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsWriter.java @@ -0,0 +1,615 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es910; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.SuppressForbidden; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicWriter; +import org.elasticsearch.index.codec.vectors.es910.hnsw.HnswGraphBuilder; +import org.elasticsearch.index.codec.vectors.es910.hnsw.IncrementalHnswGraphMerger; +import org.elasticsearch.index.codec.vectors.es910.hnsw.NeighborArray; +import org.elasticsearch.index.codec.vectors.es910.hnsw.OnHeapHnswGraph; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; +import static org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10.3.0 + */ +@SuppressForbidden(reason = "Lucene classes") +public class ES910HnswVectorsWriter extends KnnVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ES910HnswVectorsWriter.class); + private final SegmentWriteState segmentWriteState; + private final IndexOutput meta, vectorIndex; + private final int M; + private final int beamWidth; + private final FlatVectorsWriter flatVectorWriter; + + private final List> fields = new ArrayList<>(); + private boolean finished; + + public ES910HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth, FlatVectorsWriter flatVectorWriter) throws IOException { + this.M = M; + this.flatVectorWriter = flatVectorWriter; + this.beamWidth = beamWidth; + segmentWriteState = state; + + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES910HnswVectorsFormat.META_EXTENSION + ); + + String indexDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES910HnswVectorsFormat.VECTOR_INDEX_EXTENSION + ); + + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorIndex = 
state.directory.createOutput(indexDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES910HnswVectorsFormat.META_CODEC_NAME, + Lucene99HnswVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + vectorIndex, + ES910HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME, + Lucene99HnswVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + ES910HnswVectorsWriter.FieldWriter newField = ES910HnswVectorsWriter.FieldWriter.create( + flatVectorWriter.getFlatVectorScorer(), + flatVectorWriter.addField(fieldInfo), + fieldInfo, + M, + beamWidth, + segmentWriteState.infoStream + ); + fields.add(newField); + return newField; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + flatVectorWriter.flush(maxDoc, sortMap); + for (ES910HnswVectorsWriter.FieldWriter field : fields) { + if (sortMap == null) { + writeField(field); + } else { + writeSortingField(field, sortMap); + } + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + flatVectorWriter.finish(); + + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorIndex != null) { + CodecUtil.writeFooter(vectorIndex); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (ES910HnswVectorsWriter.FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(ES910HnswVectorsWriter.FieldWriter fieldData) throws IOException { + // write graph + long vectorIndexOffset = vectorIndex.getFilePointer(); + OnHeapHnswGraph graph = fieldData.getGraph(); + int[][] graphLevelNodeOffsets = writeGraph(graph); + long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; + + writeMeta( + fieldData.fieldInfo, + vectorIndexOffset, + vectorIndexLength, + fieldData.getDocsWithFieldSet().cardinality(), + graph, + graphLevelNodeOffsets + ); + } + + private void writeSortingField(ES910HnswVectorsWriter.FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException { + final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + final int[] oldOrdMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // old ord to new ord + + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, oldOrdMap, ordMap, null); + // write graph + long vectorIndexOffset = vectorIndex.getFilePointer(); + OnHeapHnswGraph graph = fieldData.getGraph(); + int[][] graphLevelNodeOffsets = graph == null ? new int[0][] : new int[graph.numLevels()][]; + HnswGraph mockGraph = reconstructAndWriteGraph(graph, ordMap, oldOrdMap, graphLevelNodeOffsets); + long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; + + writeMeta( + fieldData.fieldInfo, + vectorIndexOffset, + vectorIndexLength, + fieldData.getDocsWithFieldSet().cardinality(), + mockGraph, + graphLevelNodeOffsets + ); + } + + /** + * Reconstructs the graph given the old and new node ids. + * + *

Additionally, the graph node connections are written to the vectorIndex. + * + * @param graph The current on heap graph + * @param newToOldMap the new node ids indexed to the old node ids + * @param oldToNewMap the old node ids indexed to the new node ids + * @param levelNodeOffsets where to place the new offsets for the nodes in the vector index. + * @return The graph + * @throws IOException if writing to vectorIndex fails + */ + private HnswGraph reconstructAndWriteGraph(OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap, int[][] levelNodeOffsets) + throws IOException { + if (graph == null) return null; + + List nodesByLevel = new ArrayList<>(graph.numLevels()); + nodesByLevel.add(null); + + int maxOrd = graph.size(); + int[] scratch = new int[graph.maxConn() * 2]; + HnswGraph.NodesIterator nodesOnLevel0 = graph.getNodesOnLevel(0); + levelNodeOffsets[0] = new int[nodesOnLevel0.size()]; + while (nodesOnLevel0.hasNext()) { + int node = nodesOnLevel0.nextInt(); + NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]); + long offset = vectorIndex.getFilePointer(); + reconstructAndWriteNeighbours(neighbors, oldToNewMap, scratch, maxOrd); + levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offset); + } + + for (int level = 1; level < graph.numLevels(); level++) { + HnswGraph.NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); + int[] newNodes = new int[nodesOnLevel.size()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()]; + } + Arrays.sort(newNodes); + nodesByLevel.add(newNodes); + levelNodeOffsets[level] = new int[newNodes.length]; + int nodeOffsetIndex = 0; + for (int node : newNodes) { + NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]); + long offset = vectorIndex.getFilePointer(); + reconstructAndWriteNeighbours(neighbors, oldToNewMap, scratch, maxOrd); + levelNodeOffsets[level][nodeOffsetIndex++] = Math.toIntExact(vectorIndex.getFilePointer() - offset); + } + } + return new HnswGraph() { + @Override + public int nextNeighbor() { + throw new UnsupportedOperationException("Not supported on a mock graph"); + } + + @Override + public void seek(int level, int target) { + throw new UnsupportedOperationException("Not supported on a mock graph"); + } + + @Override + public int size() { + return graph.size(); + } + + @Override + public int numLevels() { + return graph.numLevels(); + } + + @Override + public int maxConn() { + return graph.maxConn(); + } + + @Override + public int entryNode() { + throw new UnsupportedOperationException("Not supported on a mock graph"); + } + + @Override + public int neighborCount() { + throw new UnsupportedOperationException("Not supported on a mock graph"); + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + if (level == 0) { + return graph.getNodesOnLevel(0); + } else { + return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); + } + } + }; + } + + private void reconstructAndWriteNeighbours(NeighborArray neighbors, int[] oldToNewMap, int[] scratch, int maxOrd) throws IOException { + int size = neighbors.size(); + // Destructively modify; it's ok we are discarding it after this + int[] nnodes = neighbors.nodes(); + for (int i = 0; i < size; i++) { + nnodes[i] = oldToNewMap[nnodes[i]]; + } + Arrays.sort(nnodes, 0, size); + int actualSize = 0; + if (size > 0) { + scratch[0] = nnodes[0]; + actualSize = 1; + } + // Now that we have sorted, do delta encoding to minimize the required bits to 
store the + // information + for (int i = 1; i < size; i++) { + assert nnodes[i] < maxOrd : "node too large: " + nnodes[i] + ">=" + maxOrd; + if (nnodes[i - 1] == nnodes[i]) { + continue; + } + scratch[actualSize++] = nnodes[i] - nnodes[i - 1]; + } + // Write the size after duplicates are removed + vectorIndex.writeVInt(actualSize); + for (int i = 0; i < actualSize; i++) { + vectorIndex.writeVInt(scratch[i]); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + CloseableRandomVectorScorerSupplier scorerSupplier = flatVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState); + boolean success = false; + try { + long vectorIndexOffset = vectorIndex.getFilePointer(); + // build the graph using the temporary vector data + // we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction + // doesn't need to know docIds + // TODO: separate random access vector values from DocIdSetIterator? + OnHeapHnswGraph graph = null; + int[][] vectorIndexNodeOffsets = null; + if (scorerSupplier.totalVectorCount() > 0) { + // build graph + IncrementalHnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier); + for (int i = 0; i < mergeState.liveDocs.length; i++) { + if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) { + merger.addReader(mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); + } + } + KnnVectorValues mergedVectorValues = null; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + case FLOAT32 -> mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + } + graph = merger.merge(mergedVectorValues, segmentWriteState.infoStream, scorerSupplier.totalVectorCount()); + vectorIndexNodeOffsets = writeGraph(graph); + } + long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; + writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, scorerSupplier.totalVectorCount(), graph, vectorIndexNodeOffsets); + success = true; + } finally { + if (success) { + IOUtils.close(scorerSupplier); + } else { + IOUtils.closeWhileHandlingException(scorerSupplier); + } + } + } + + /** + * @param graph Write the graph in a compressed format + * @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets. 
+ * @throws IOException if writing to vectorIndex fails + */ + private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException { + if (graph == null) return new int[0][0]; + // write vectors' neighbours on each level into the vectorIndex file + int countOnLevel0 = graph.size(); + int[][] offsets = new int[graph.numLevels()][]; + int[] scratch = new int[graph.maxConn() * 2]; + for (int level = 0; level < graph.numLevels(); level++) { + int[] sortedNodes = HnswGraph.NodesIterator.getSortedNodes(graph.getNodesOnLevel(level)); + offsets[level] = new int[sortedNodes.length]; + int nodeOffsetId = 0; + for (int node : sortedNodes) { + NeighborArray neighbors = graph.getNeighbors(level, node); + int size = neighbors.size(); + // Write size in VInt as the neighbors list is typically small + long offsetStart = vectorIndex.getFilePointer(); + int[] nnodes = neighbors.nodes(); + Arrays.sort(nnodes, 0, size); + // Now that we have sorted, do delta encoding to minimize the required bits to store the + // information + int actualSize = 0; + if (size > 0) { + scratch[0] = nnodes[0]; + actualSize = 1; + } + for (int i = 1; i < size; i++) { + assert nnodes[i] < countOnLevel0 : "node too large: " + nnodes[i] + ">=" + countOnLevel0; + if (nnodes[i - 1] == nnodes[i]) { + continue; + } + scratch[actualSize++] = nnodes[i] - nnodes[i - 1]; + } + // Write the size after duplicates are removed + vectorIndex.writeVInt(actualSize); + for (int i = 0; i < actualSize; i++) { + vectorIndex.writeVInt(scratch[i]); + } + offsets[level][nodeOffsetId++] = Math.toIntExact(vectorIndex.getFilePointer() - offsetStart); + } + } + return offsets; + } + + private void writeMeta( + FieldInfo field, + long vectorIndexOffset, + long vectorIndexLength, + int count, + HnswGraph graph, + int[][] graphLevelNodeOffsets + ) throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); + meta.writeVLong(vectorIndexOffset); + meta.writeVLong(vectorIndexLength); + meta.writeVInt(field.getVectorDimension()); + meta.writeInt(count); + meta.writeVInt(M); + // write graph nodes on each level + if (graph == null) { + meta.writeVInt(0); + } else { + meta.writeVInt(graph.numLevels()); + long valueCount = 0; + for (int level = 0; level < graph.numLevels(); level++) { + HnswGraph.NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); + valueCount += nodesOnLevel.size(); + if (level > 0) { + int[] nol = new int[nodesOnLevel.size()]; + int numberConsumed = nodesOnLevel.consume(nol); + Arrays.sort(nol); + assert numberConsumed == nodesOnLevel.size(); + meta.writeVInt(nol.length); // number of nodes on a level + for (int i = nodesOnLevel.size() - 1; i > 0; --i) { + nol[i] -= nol[i - 1]; + } + for (int n : nol) { + assert n >= 0 : "delta encoding for nodes failed; expected nodes to be sorted"; + meta.writeVInt(n); + } + } else { + assert nodesOnLevel.size() == count : "Level 0 expects to have all nodes"; + } + } + long start = vectorIndex.getFilePointer(); + meta.writeLong(start); + meta.writeVInt(DIRECT_MONOTONIC_BLOCK_SHIFT); + final DirectMonotonicWriter memoryOffsetsWriter = DirectMonotonicWriter.getInstance( + meta, + vectorIndex, + valueCount, + DIRECT_MONOTONIC_BLOCK_SHIFT + ); + long cumulativeOffsetSum = 0; + for (int[] levelOffsets : graphLevelNodeOffsets) { + for (int v : levelOffsets) { + memoryOffsetsWriter.add(cumulativeOffsetSum); + cumulativeOffsetSum += v; + } + } + memoryOffsetsWriter.finish(); + 
meta.writeLong(vectorIndex.getFilePointer() - start); + } + } + + private IncrementalHnswGraphMerger createGraphMerger(FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) { + return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth); + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorIndex, flatVectorWriter); + } + + static int distFuncToOrd(VectorSimilarityFunction func) { + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { + return (byte) i; + } + } + throw new IllegalArgumentException("invalid distance function: " + func); + } + + private static class FieldWriter extends KnnFieldVectorsWriter { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES910HnswVectorsWriter.FieldWriter.class); + + private final FieldInfo fieldInfo; + private final HnswGraphBuilder hnswGraphBuilder; + private int lastDocID = -1; + private int node = 0; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private UpdateableRandomVectorScorer scorer; + + @SuppressWarnings("unchecked") + static ES910HnswVectorsWriter.FieldWriter create( + FlatVectorsScorer scorer, + FlatFieldVectorsWriter flatFieldVectorsWriter, + FieldInfo fieldInfo, + int M, + int beamWidth, + InfoStream infoStream + ) throws IOException { + return switch (fieldInfo.getVectorEncoding()) { + case BYTE -> new ES910HnswVectorsWriter.FieldWriter<>( + scorer, + (FlatFieldVectorsWriter) flatFieldVectorsWriter, + fieldInfo, + M, + beamWidth, + infoStream + ); + case FLOAT32 -> new ES910HnswVectorsWriter.FieldWriter<>( + scorer, + (FlatFieldVectorsWriter) flatFieldVectorsWriter, + fieldInfo, + M, + beamWidth, + infoStream + ); + }; + } + + @SuppressWarnings("unchecked") + FieldWriter( + FlatVectorsScorer scorer, + FlatFieldVectorsWriter flatFieldVectorsWriter, + FieldInfo fieldInfo, + int M, + int beamWidth, + InfoStream infoStream + ) throws IOException { + this.fieldInfo = fieldInfo; + RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { + case BYTE -> scorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + ByteVectorValues.fromBytes((List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension()) + ); + case FLOAT32 -> scorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + FloatVectorValues.fromFloats((List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension()) + ); + }; + this.scorer = scorerSupplier.scorer(); + hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); + hnswGraphBuilder.setInfoStream(infoStream); + this.flatFieldVectorsWriter = Objects.requireNonNull(flatFieldVectorsWriter); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)" + ); + } + flatFieldVectorsWriter.addValue(docID, vectorValue); + scorer.setScoringOrdinal(node); + hnswGraphBuilder.addGraphNode(node, scorer); + node++; + lastDocID = docID; + } + + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public T copyValue(T vectorValue) { + throw new UnsupportedOperationException(); + } + + OnHeapHnswGraph getGraph() throws IOException { + assert 
flatFieldVectorsWriter.isFinished(); + if (node > 0) { + return hnswGraphBuilder.getCompletedGraph(); + } else { + return null; + } + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + flatFieldVectorsWriter.ramBytesUsed() + hnswGraphBuilder.getGraph().ramBytesUsed(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/AbstractHnswGraphSearcher.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/AbstractHnswGraphSearcher.java new file mode 100644 index 0000000000000..ec8343c6ea334 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/AbstractHnswGraphSearcher.java @@ -0,0 +1,81 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.IOException; + +/** + * AbstractHnswGraphSearcher is the base class for HnswGraphSearcher implementations. + * + * @lucene.experimental + */ +abstract class AbstractHnswGraphSearcher { + + static final int UNK_EP = -1; + + /** + * Search a given level of the graph starting at the given entry points. + * + * @param results the collector to collect the results + * @param scorer the scorer to compare the query with the nodes + * @param level the level of the graph to search + * @param eps the entry points to start the search from + * @param graph the HNSWGraph + * @param acceptOrds the ordinals to accept for the results + */ + abstract void searchLevel(KnnCollector results, RandomVectorScorer scorer, int level, int[] eps, HnswGraph graph, Bits acceptOrds) + throws IOException; + + /** + * Function to find the best entry point from which to search the zeroth graph layer. + * + * @param scorer the scorer to compare the query with the nodes + * @param graph the HNSWGraph + * @param collector the knn result collector + * @return the best entry point, `-1` indicates graph entry node not set, or visitation limit + * exceeded + * @throws IOException When accessing the vectors or graph fails + */ + abstract int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector) throws IOException; + + /** + * Search the graph for the given scorer. Gathering results in the provided collector that pass + * the provided acceptOrds. 
+ * + * @param results the collector to collect the results + * @param scorer the scorer to compare the query with the nodes + * @param graph the HNSWGraph + * @param acceptOrds the ordinals to accept for the results + * @throws IOException When accessing the vectors or graph fails + */ + public void search(KnnCollector results, RandomVectorScorer scorer, HnswGraph graph, Bits acceptOrds) throws IOException { + int[] eps = findBestEntryPoint(scorer, graph, results); + assert eps != null && eps.length > 0; + if (eps[0] == UNK_EP) { + return; + } + searchLevel(results, scorer, 0, eps, graph, acceptOrds); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswGraphBuilder.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswGraphBuilder.java new file mode 100644 index 0000000000000..0e73e12450a39 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswGraphBuilder.java @@ -0,0 +1,525 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.internal.hppc.IntHashSet; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.NeighborQueue; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; +import java.util.SplittableRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; + +import static java.lang.Math.log; + +/** + * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the + * hyper-parameters. 
+ */ +public class HnswGraphBuilder { + + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** Default random seed for level generation * */ + private static final long DEFAULT_RAND_SEED = 42; + + /** A name for the HNSW component for the info-stream * */ + public static final String HNSW_COMPONENT = "HNSW"; + + /** Random seed for level generation; public to expose for testing * */ + @SuppressWarnings("NonFinalStaticField") + public static long randSeed = DEFAULT_RAND_SEED; + + private final int M; // max number of connections on upper layers + private final double ml; + + private final SplittableRandom random; + protected final RandomVectorScorerSupplier scorerSupplier; + private final HnswGraphSearcher graphSearcher; + private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search + private final GraphBuilderKnnCollector beamCandidates; // for levels of graph where we add the node + private final GraphBuilderKnnCollector beamCandidates0; + + protected final OnHeapHnswGraph hnsw; + protected final HnswLock hnswLock; + + protected InfoStream infoStream = InfoStream.getDefault(); + protected boolean frozen; + + public static HnswGraphBuilder create(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException { + return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, -1); + } + + public static HnswGraphBuilder create(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) + throws IOException { + return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed, graphSize); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param scorerSupplier a supplier to create vector scorer from ordinals. + * @param M – graph fanout parameter used to calculate the maximum number of connections a node + * can have – M on upper layers, and M * 2 on the lowest level. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param seed the seed for a random number generator used during graph construction. Provide this + * to ensure repeatable construction. + * @param graphSize size of graph, if unknown, pass in -1 + */ + protected HnswGraphBuilder(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize) + throws IOException { + this(scorerSupplier, beamWidth, seed, new OnHeapHnswGraph(M, graphSize)); + } + + protected HnswGraphBuilder(RandomVectorScorerSupplier scorerSupplier, int beamWidth, long seed, OnHeapHnswGraph hnsw) + throws IOException { + this( + scorerSupplier, + beamWidth, + seed, + hnsw, + null, + new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.size())) + ); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param scorerSupplier a supplier to create vector scorer from ordinals. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param seed the seed for a random number generator used during graph construction. Provide this + * to ensure repeatable construction. 
+ * @param hnsw the graph to build, can be previously initialized + */ + protected HnswGraphBuilder( + RandomVectorScorerSupplier scorerSupplier, + int beamWidth, + long seed, + OnHeapHnswGraph hnsw, + HnswLock hnswLock, + HnswGraphSearcher graphSearcher + ) throws IOException { + if (hnsw.maxConn() <= 0) { + throw new IllegalArgumentException("M (max connections) must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + this.M = hnsw.maxConn(); + this.scorerSupplier = Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null"); + // normalization factor for level generation; currently not configurable + this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); + this.random = new SplittableRandom(seed); + this.hnsw = hnsw; + this.hnswLock = hnswLock; + this.graphSearcher = graphSearcher; + entryCandidates = new GraphBuilderKnnCollector(1); + beamCandidates = new GraphBuilderKnnCollector(beamWidth); + beamCandidates0 = new GraphBuilderKnnCollector(Math.min(beamWidth / 2, M * 3)); + } + + public OnHeapHnswGraph build(int maxOrd) throws IOException { + if (frozen) { + throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors"); + } + addVectors(maxOrd); + return getCompletedGraph(); + } + + public void setInfoStream(InfoStream infoStream) { + this.infoStream = infoStream; + } + + public OnHeapHnswGraph getCompletedGraph() throws IOException { + if (frozen == false) { + finish(); + } + return getGraph(); + } + + public OnHeapHnswGraph getGraph() { + return hnsw; + } + + /** add vectors in range [minOrd, maxOrd) */ + protected void addVectors(int minOrd, int maxOrd) throws IOException { + if (frozen) { + throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated"); + } + long start = System.nanoTime(), t = start; + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")"); + } + UpdateableRandomVectorScorer scorer = scorerSupplier.scorer(); + for (int node = minOrd; node < maxOrd; node++) { + scorer.setScoringOrdinal(node); + addGraphNode(node, scorer); + if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { + t = printGraphBuildStatus(node, start, t); + } + } + } + + private void addVectors(int maxOrd) throws IOException { + addVectors(0, maxOrd); + } + + public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException { + addGraphNodeInternal(node, scorer, null); + } + + private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer, IntHashSet eps0) throws IOException { + if (frozen) { + throw new IllegalStateException("Graph builder is already frozen"); + } + final int nodeLevel = getRandomGraphLevel(ml, random); + // first add nodes to all levels + for (int level = nodeLevel; level >= 0; level--) { + hnsw.addNode(level, node); + } + // then promote itself as entry node if entry node is not set + if (hnsw.trySetNewEntryNode(node, nodeLevel)) { + return; + } + // if the entry node is already set, then we have to do all connections first before we can + // promote ourselves as entry node + + int lowestUnsetLevel = 0; + int curMaxLevel; + do { + curMaxLevel = hnsw.numLevels() - 1; + // NOTE: the entry node and max level may not be paired, but because we get the level first + // we ensure that the entry node we get later 
will always exist on the curMaxLevel + int[] eps = new int[] { hnsw.entryNode() }; + + // we first do the search from top to bottom + // for levels > nodeLevel search with topk = 1 + GraphBuilderKnnCollector candidates = entryCandidates; + for (int level = curMaxLevel; level > nodeLevel; level--) { + candidates.clear(); + graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); + eps[0] = candidates.popNode(); + } + + // for levels <= nodeLevel search with topk = beamWidth, and add connections + candidates = beamCandidates; + NeighborArray[] scratchPerLevel = new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1]; + for (int i = scratchPerLevel.length - 1; i >= 0; i--) { + int level = i + lowestUnsetLevel; + candidates.clear(); + if (level == 0 && eps0 != null && eps0.size() > 0) { + eps = eps0.toArray(); + candidates = beamCandidates0; + } + graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); + eps = candidates.popUntilNearestKNodes(); + scratchPerLevel[i] = new NeighborArray(Math.max(candidates.k(), M + 1), false); + popToScratch(candidates, scratchPerLevel[i]); + } + + // then do connections from bottom up + for (int i = 0; i < scratchPerLevel.length; i++) { + addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], scorer); + } + lowestUnsetLevel += scratchPerLevel.length; + assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1; + if (lowestUnsetLevel > nodeLevel) { + return; + } + assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel; + if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) { + return; + } + if (hnsw.numLevels() == curMaxLevel + 1) { + // This should never happen if all the calculations are correct + throw new IllegalStateException( + "We're not able to promote node " + + node + + " at level " + + nodeLevel + + " as entry node. But the max graph level " + + curMaxLevel + + " has not changed while we are inserting the node." + ); + } + } while (true); + } + + public void addGraphNode(int node) throws IOException { + /* + Note: this implementation is thread safe when graph size is fixed (e.g. when merging) + The process of adding a node is roughly: + 1. Add the node to all level from top to the bottom, but do not connect it to any other node, + nor try to promote itself to an entry node before the connection is done. (Unless the graph is empty + and this is the first node, in that case we set the entry node and return) + 2. Do the search from top to bottom, remember all the possible neighbours on each level the node + is on. + 3. Add the neighbor to the node from bottom to top level, when adding the neighbour, + we always add all the outgoing links first before adding incoming link such that + when a search visits this node, it can always find a way out + 4. If the node has level that is less or equal to graph level, then we're done here. + If the node has level larger than graph level, then we need to promote the node + as the entry node. If, while we add the node to the graph, the entry node has changed + (which means the graph level has changed as well), we need to reinsert the node + to the newly introduced levels (repeating step 2,3 for new levels) and again try to + promote the node to entry node. 
+ */ + UpdateableRandomVectorScorer scorer = scorerSupplier.scorer(); + scorer.setScoringOrdinal(node); + addGraphNodeInternal(node, scorer, null); + } + + public void addGraphNodeWithEps(int node, IntHashSet eps0) throws IOException { + UpdateableRandomVectorScorer scorer = scorerSupplier.scorer(); + scorer.setScoringOrdinal(node); + addGraphNodeInternal(node, scorer, eps0); + } + + private long printGraphBuildStatus(int node, long start, long t) { + long now = System.nanoTime(); + infoStream.message( + HNSW_COMPONENT, + String.format( + Locale.ROOT, + "built %d in %d/%d ms", + node, + TimeUnit.NANOSECONDS.toMillis(now - t), + TimeUnit.NANOSECONDS.toMillis(now - start) + ) + ); + return now; + } + + private void addDiverseNeighbors(int level, int node, NeighborArray candidates, UpdateableRandomVectorScorer scorer) + throws IOException { + /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it + * is closer to target than it is to any of the already-selected neighbors (ie selected in this method, + * since the node is new and has no prior neighbors). + */ + NeighborArray neighbors = hnsw.getNeighbors(level, node); + assert neighbors.size() == 0; // new node + int maxConnOnLevel = level == 0 ? M * 2 : M; + boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, scorer); + + // Link the selected nodes to the new node, and the new node to the selected nodes (again + // applying diversity heuristic) + // NOTE: here we're using candidates and mask but not the neighbour array because once we have + // added incoming link there will be possibilities of this node being discovered and neighbour + // array being modified. So using local candidates and mask is a safer option. + for (int i = 0; i < candidates.size(); i++) { + if (mask[i] == false) { + continue; + } + int nbr = candidates.nodes()[i]; + if (hnswLock != null) { + Lock lock = hnswLock.write(level, nbr); + try { + NeighborArray nbrsOfNbr = getGraph().getNeighbors(level, nbr); + nbrsOfNbr.addAndEnsureDiversity(node, candidates.getScores(i), nbr, scorer); + } finally { + lock.unlock(); + } + } else { + NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); + nbrsOfNbr.addAndEnsureDiversity(node, candidates.getScores(i), nbr, scorer); + } + } + } + + /** + * This method will select neighbors to add and return a mask telling the caller which candidates + * are selected + */ + private boolean[] selectAndLinkDiverse( + NeighborArray neighbors, + NeighborArray candidates, + int maxConnOnLevel, + UpdateableRandomVectorScorer scorer + ) throws IOException { + boolean[] mask = new boolean[candidates.size()]; + // Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic + for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) { + // compare each neighbor (in distance order) against the closer neighbors selected so far, + // only adding it if it is closer to the target than to any of the other selected neighbors + int cNode = candidates.nodes()[i]; + float cScore = candidates.getScores(i); + assert cNode <= hnsw.maxNodeId(); + scorer.setScoringOrdinal(cNode); + if (diversityCheck(cScore, neighbors, scorer)) { + mask[i] = true; + // here we don't need to lock, because there's no incoming link so no others is able to + // discover this node such that no others will modify this neighbor array as well + neighbors.addInOrder(cNode, cScore); + } + } + return mask; + } + + private static void popToScratch(GraphBuilderKnnCollector 
candidates, NeighborArray scratch) { + scratch.clear(); + int candidateCount = candidates.size(); + // extract all the Neighbors from the queue into an array; these will now be + // sorted from worst to best + for (int i = 0; i < candidateCount; i++) { + float maxSimilarity = candidates.minimumScore(); + scratch.addInOrder(candidates.popNode(), maxSimilarity); + } + } + + /** + * @param score the score of the new candidate and node n, to be compared with scores of the + * candidate and n's neighbors + * @param neighbors the neighbors selected so far + * @return whether the candidate is diverse given the existing neighbors + */ + private boolean diversityCheck(float score, NeighborArray neighbors, RandomVectorScorer scorer) throws IOException { + for (int i = 0; i < neighbors.size(); i++) { + float neighborSimilarity = scorer.score(neighbors.nodes()[i]); + if (neighborSimilarity >= score) { + return false; + } + } + return true; + } + + private static int getRandomGraphLevel(double ml, SplittableRandom random) { + double randDouble; + do { + randDouble = random.nextDouble(); // avoid 0 value, as log(0) is undefined + } while (randDouble == 0.0); + return ((int) (-log(randDouble) * ml)); + } + + void finish() throws IOException { + // System.out.println("finish " + frozen); + // TODO: Connect components can be exceptionally expensive, disabling + // see: https://github.com/apache/lucene/issues/14214 + // connectComponents(); + frozen = true; + hnsw.finishBuild(); + } + + /** + * A restricted, specialized knnCollector that can be used when building a graph. + * + *

Does not support TopDocs + */ + public static final class GraphBuilderKnnCollector implements KnnCollector { + private final NeighborQueue queue; + private final int k; + private long visitedCount; + + /** + * @param k the number of neighbors to collect + */ + public GraphBuilderKnnCollector(int k) { + this.queue = new NeighborQueue(k, false); + this.k = k; + } + + public int size() { + return queue.size(); + } + + public int popNode() { + return queue.pop(); + } + + public int[] popUntilNearestKNodes() { + while (size() > k()) { + queue.pop(); + } + return queue.nodes(); + } + + public float minimumScore() { + return queue.topScore(); + } + + public void clear() { + this.queue.clear(); + this.visitedCount = 0; + } + + @Override + public boolean earlyTerminated() { + return false; + } + + @Override + public void incVisitedCount(int count) { + this.visitedCount += count; + } + + @Override + public long visitedCount() { + return visitedCount; + } + + @Override + public long visitLimit() { + return Long.MAX_VALUE; + } + + @Override + public int k() { + return k; + } + + @Override + public boolean collect(int docId, float similarity) { + return queue.insertWithOverflow(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return queue.size() >= k() ? queue.topScore() : Float.NEGATIVE_INFINITY; + } + + @Override + public TopDocs topDocs() { + throw new IllegalArgumentException(); + } + + @Override + public KnnSearchStrategy getSearchStrategy() { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswGraphSearcher.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswGraphSearcher.java new file mode 100644 index 0000000000000..7306fbdf74e68 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswGraphSearcher.java @@ -0,0 +1,224 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.NeighborQueue; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the + * search algorithm, see {@link HnswGraph}. 
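The searcher follows the standard two-phase HNSW procedure: a greedy walk down the upper layers to pick a good level-0 entry point, then a best-first expansion on layer 0 bounded by the current top-k. As a rough, self-contained illustration of that control flow (plain java.util collections and an arbitrary score function stand in for the graph, scorer and collector used by the class below):

    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;
    import java.util.PriorityQueue;
    import java.util.Set;
    import java.util.function.IntToDoubleFunction;

    class ToyHnswSearch {

        /** layers.get(level).get(node) lists node's neighbours on that level; score gives similarity to the query. */
        static List<Integer> search(List<Map<Integer, List<Integer>>> layers, int entryNode, IntToDoubleFunction score, int k) {
            // Phase 1: greedy descent through the upper layers, always moving to a better-scoring neighbour.
            int ep = entryNode;
            for (int level = layers.size() - 1; level >= 1; level--) {
                boolean improved = true;
                while (improved) {
                    improved = false;
                    for (int nbr : layers.get(level).getOrDefault(ep, List.of())) {
                        if (score.applyAsDouble(nbr) > score.applyAsDouble(ep)) {
                            ep = nbr;
                            improved = true;
                        }
                    }
                }
            }
            // Phase 2: best-first expansion on layer 0, keeping the k best results seen so far.
            Set<Integer> visited = new HashSet<>(List.of(ep));
            PriorityQueue<Integer> candidates = new PriorityQueue<>((a, b) -> Double.compare(score.applyAsDouble(b), score.applyAsDouble(a)));
            PriorityQueue<Integer> results = new PriorityQueue<>((a, b) -> Double.compare(score.applyAsDouble(a), score.applyAsDouble(b)));
            candidates.add(ep);
            results.add(ep);
            while (candidates.isEmpty() == false) {
                int c = candidates.poll();
                if (results.size() >= k && score.applyAsDouble(c) < score.applyAsDouble(results.peek())) {
                    break; // no remaining candidate can improve the current top-k
                }
                for (int nbr : layers.get(0).getOrDefault(c, List.of())) {
                    if (visited.add(nbr) == false) {
                        continue;
                    }
                    if (results.size() < k || score.applyAsDouble(nbr) > score.applyAsDouble(results.peek())) {
                        candidates.add(nbr);
                        results.add(nbr);
                        if (results.size() > k) {
                            results.poll(); // evict the weakest collected result
                        }
                    }
                }
            }
            return new ArrayList<>(results); // the best node ids found, in no particular order
        }
    }

The class below additionally honours visit limits, the acceptOrds filter, and the shouldExploreMinSim handling of score ties visible in searchLevel.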
+ */ +public class HnswGraphSearcher extends AbstractHnswGraphSearcher { + /** + * Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive + * to allocate, so they're cleared and reused across calls. + */ + protected final NeighborQueue candidates; + + protected BitSet visited; + + /** + * Creates a new graph searcher. + * + * @param candidates max heap that will track the candidate nodes to explore + * @param visited bit set that will track nodes that have already been visited + */ + public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { + this.candidates = candidates; + this.visited = visited; + } + + /** + * Function to find the best entry point from which to search the zeroth graph layer. + * + * @param scorer the scorer to compare the query with the nodes + * @param graph the HNSWGraph + * @param collector the knn result collector + * @return the best entry point, `-1` indicates graph entry node not set, or visitation limit + * exceeded + * @throws IOException When accessing the vector fails + */ + @Override + int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector) throws IOException { + int currentEp = graph.entryNode(); + if (currentEp == -1 || graph.numLevels() == 1) { + return new int[] { currentEp }; + } + int size = getGraphSize(graph); + prepareScratchState(size); + float currentScore = scorer.score(currentEp); + collector.incVisitedCount(1); + boolean foundBetter; + for (int level = graph.numLevels() - 1; level >= 1; level--) { + foundBetter = true; + visited.set(currentEp); + // Keep searching the given level until we stop finding a better candidate entry point + while (foundBetter) { + foundBetter = false; + graphSeek(graph, level, currentEp); + int friendOrd; + while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) { + assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; + if (visited.getAndSet(friendOrd)) { + continue; + } + if (collector.earlyTerminated()) { + return new int[] { UNK_EP }; + } + float friendSimilarity = scorer.score(friendOrd); + collector.incVisitedCount(1); + if (friendSimilarity > currentScore) { + currentScore = friendSimilarity; + currentEp = friendOrd; + foundBetter = true; + } + } + } + } + return collector.earlyTerminated() ? new int[] { UNK_EP } : new int[] { currentEp }; + } + + /** + * Add the closest neighbors found to a priority queue (heap). These are returned in REVERSE + * proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest + * score/comparison value, will be at the top of the heap, while the closest neighbor will be the + * last to be popped. + */ + @Override + void searchLevel(KnnCollector results, RandomVectorScorer scorer, int level, final int[] eps, HnswGraph graph, Bits acceptOrds) + throws IOException { + + int size = getGraphSize(graph); + + prepareScratchState(size); + + for (int ep : eps) { + if (visited.getAndSet(ep) == false) { + if (results.earlyTerminated()) { + break; + } + float score = scorer.score(ep); + results.incVisitedCount(1); + candidates.add(ep, score); + if (acceptOrds == null || acceptOrds.get(ep)) { + results.collect(ep, score); + } + } + } + + // A bound that holds the minimum similarity to the query vector that a candidate vector must + // have to be considered. 
+ float minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); + // We should allow exploring equivalent minAcceptedSimilarity values at least once + boolean shouldExploreMinSim = true; + while (candidates.size() > 0 && results.earlyTerminated() == false) { + // get the best candidate (closest or best scoring) + float topCandidateSimilarity = candidates.topScore(); + if (topCandidateSimilarity < minAcceptedSimilarity) { + // if the similarity is equivalent to the minAcceptedSimilarity, + // we should explore one candidate + // however, running into many duplicates can be expensive, + // so we should stop exploring if equivalent minimum scores are found + if (shouldExploreMinSim && Math.nextUp(topCandidateSimilarity) == minAcceptedSimilarity) { + shouldExploreMinSim = false; + } else { + break; + } + } + + int topCandidateNode = candidates.pop(); + graphSeek(graph, level, topCandidateNode); + int friendOrd; + while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) { + assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; + if (visited.getAndSet(friendOrd)) { + continue; + } + + if (results.earlyTerminated()) { + break; + } + float friendSimilarity = scorer.score(friendOrd); + results.incVisitedCount(1); + if (friendSimilarity >= minAcceptedSimilarity) { + candidates.add(friendOrd, friendSimilarity); + if (acceptOrds == null || acceptOrds.get(friendOrd)) { + if (results.collect(friendOrd, friendSimilarity)) { + float oldMinAcceptedSimilarity = minAcceptedSimilarity; + minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); + if (minAcceptedSimilarity > oldMinAcceptedSimilarity) { + // we adjusted our minAcceptedSimilarity, so we should explore the next equivalent + // if necessary + shouldExploreMinSim = true; + } + } + } + } + } + if (results.getSearchStrategy() != null) { + results.getSearchStrategy().nextVectorsBlock(); + } + } + } + + private void prepareScratchState(int capacity) { + candidates.clear(); + if (visited.length() < capacity) { + visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity); + } + visited.clear(); + } + + /** + * Seek a specific node in the given graph. The default implementation will just call {@link + * HnswGraph#seek(int, int)} + * + * @throws IOException when seeking the graph + */ + void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException { + graph.seek(level, targetNode); + } + + /** + * Get the next neighbor from the graph, you must call {@link #graphSeek(HnswGraph, int, int)} + * before calling this method. The default implementation will just call {@link + * HnswGraph#nextNeighbor()} + * + * @return see {@link HnswGraph#nextNeighbor()} + * @throws IOException when advance neighbors + */ + int graphNextNeighbor(HnswGraph graph) throws IOException { + return graph.nextNeighbor(); + } + + static int getGraphSize(HnswGraph graph) { + return graph.maxNodeId() + 1; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswLock.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswLock.java new file mode 100644 index 0000000000000..4137663a5cb11 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/HnswLock.java @@ -0,0 +1,54 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.util.hnsw.HnswConcurrentMergeBuilder; +import org.apache.lucene.util.hnsw.OnHeapHnswGraph; + +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * Provide (read-and-write) striped locks for access to nodes of an {@link OnHeapHnswGraph}. For use + * by {@link HnswConcurrentMergeBuilder} and its HnswGraphBuilders. + */ +final class HnswLock { + private static final int NUM_LOCKS = 512; + private final ReentrantReadWriteLock[] locks; + + HnswLock() { + locks = new ReentrantReadWriteLock[NUM_LOCKS]; + for (int i = 0; i < NUM_LOCKS; i++) { + locks[i] = new ReentrantReadWriteLock(); + } + } + + Lock write(int level, int node) { + int lockid = hash(level, node) % NUM_LOCKS; + Lock lock = locks[lockid].writeLock(); + lock.lock(); + return lock; + } + + private static int hash(int v1, int v2) { + return v1 * 31 + v2; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/IncrementalHnswGraphMerger.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/IncrementalHnswGraphMerger.java new file mode 100644 index 0000000000000..b084ff7301fa3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/IncrementalHnswGraphMerger.java @@ -0,0 +1,195 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. 
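HnswLock above spreads contention by hashing (level, node) onto a fixed pool of reentrant read-write locks, so concurrent builders only wait on each other when they touch the same stripe. A minimal sketch of the same striping idea; the pool size is illustrative (the patch uses 512 locks) and floorMod is used here purely to keep the example index non-negative:

    import java.util.concurrent.locks.Lock;
    import java.util.concurrent.locks.ReentrantReadWriteLock;

    class StripedLockSketch {

        private static final int NUM_STRIPES = 64; // illustrative pool size
        private final ReentrantReadWriteLock[] stripes = new ReentrantReadWriteLock[NUM_STRIPES];

        StripedLockSketch() {
            for (int i = 0; i < NUM_STRIPES; i++) {
                stripes[i] = new ReentrantReadWriteLock();
            }
        }

        /** Acquire the write lock of the stripe that (level, node) hashes to; the caller must unlock it. */
        Lock lockForWrite(int level, int node) {
            int stripe = Math.floorMod(level * 31 + node, NUM_STRIPES);
            Lock lock = stripes[stripe].writeLock();
            lock.lock();
            return lock;
        }
    }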
+ */ +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.internal.hppc.IntIntHashMap; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * This merges multiple graphs in a single thread in incremental fashion. + */ +public class IncrementalHnswGraphMerger { + + protected final FieldInfo fieldInfo; + protected final RandomVectorScorerSupplier scorerSupplier; + protected final int M; + protected final int beamWidth; + + protected List graphReaders = new ArrayList<>(); + private int numReaders = 0; + + /** Represents a vector reader that contains graph info. */ + protected record GraphReader(KnnVectorsReader reader, MergeState.DocMap initDocMap, int graphSize) {} + + /** + * @param fieldInfo FieldInfo for the field being merged + */ + public IncrementalHnswGraphMerger(FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth) { + this.fieldInfo = fieldInfo; + this.scorerSupplier = scorerSupplier; + this.M = M; + this.beamWidth = beamWidth; + } + + /** + * Adds a reader to the graph merger if it meets the following criteria: 1. does not contain any + * deleted docs 2. 
is a HnswGraphProvider + */ + public IncrementalHnswGraphMerger addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException { + numReaders++; + if (hasDeletes(liveDocs) || (reader instanceof HnswGraphProvider == false)) { + return this; + } + HnswGraph graph = ((HnswGraphProvider) reader).getGraph(fieldInfo.name); + if (graph == null || graph.size() == 0) { + return this; + } + + int candidateVectorCount = 0; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> { + ByteVectorValues byteVectorValues = reader.getByteVectorValues(fieldInfo.name); + if (byteVectorValues == null) { + return this; + } + candidateVectorCount = byteVectorValues.size(); + } + case FLOAT32 -> { + FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + return this; + } + candidateVectorCount = vectorValues.size(); + } + } + graphReaders.add(new GraphReader(reader, docMap, candidateVectorCount)); + return this; + } + + /** + * Builds a new HnswGraphBuilder + * + * @param mergedVectorValues vector values in the merged segment + * @param maxOrd max num of vectors that will be merged into the graph + * @return HnswGraphBuilder + * @throws IOException If an error occurs while reading from the merge state + */ + protected HnswGraphBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd) throws IOException { + if (graphReaders.size() == 0) { + return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, maxOrd); + } + graphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed()); + + final BitSet initializedNodes = graphReaders.size() == numReaders ? null : new FixedBitSet(maxOrd); + int[][] ordMaps = getNewOrdMapping(mergedVectorValues, initializedNodes); + HnswGraph[] graphs = new HnswGraph[graphReaders.size()]; + for (int i = 0; i < graphReaders.size(); i++) { + HnswGraph graph = ((HnswGraphProvider) graphReaders.get(i).reader).getGraph(fieldInfo.name); + if (graph.size() == 0) { + throw new IllegalStateException("Graph should not be empty"); + } + graphs[i] = graph; + } + + return MergingHnswGraphBuilder.fromGraphs( + scorerSupplier, + beamWidth, + HnswGraphBuilder.randSeed, + graphs, + ordMaps, + maxOrd, + initializedNodes + ); + } + + protected final int[][] getNewOrdMapping(KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException { + final int numGraphs = graphReaders.size(); + IntIntHashMap[] newDocIdToOldOrdinals = new IntIntHashMap[numGraphs]; + final int[][] oldToNewOrdinalMap = new int[numGraphs][]; + for (int i = 0; i < numGraphs; i++) { + KnnVectorValues.DocIndexIterator vectorsIter = null; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> vectorsIter = graphReaders.get(i).reader.getByteVectorValues(fieldInfo.name).iterator(); + case FLOAT32 -> vectorsIter = graphReaders.get(i).reader.getFloatVectorValues(fieldInfo.name).iterator(); + } + newDocIdToOldOrdinals[i] = new IntIntHashMap(graphReaders.get(i).graphSize); + MergeState.DocMap docMap = graphReaders.get(i).initDocMap(); + for (int docId = vectorsIter.nextDoc(); docId != NO_MORE_DOCS; docId = vectorsIter.nextDoc()) { + int newDocId = docMap.get(docId); + newDocIdToOldOrdinals[i].put(newDocId, vectorsIter.index()); + } + oldToNewOrdinalMap[i] = new int[graphReaders.get(i).graphSize]; + } + + KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator(); + for (int docId = mergedVectorIterator.nextDoc(); docId < NO_MORE_DOCS; docId = 
mergedVectorIterator.nextDoc()) { + int newOrd = mergedVectorIterator.index(); + for (int i = 0; i < numGraphs; i++) { + int oldOrd = newDocIdToOldOrdinals[i].getOrDefault(docId, -1); + if (oldOrd != -1) { + oldToNewOrdinalMap[i][oldOrd] = newOrd; + if (initializedNodes != null) { + initializedNodes.set(newOrd); + } + break; + } + } + } + return oldToNewOrdinalMap; + } + + public OnHeapHnswGraph merge(KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException { + HnswGraphBuilder builder = createBuilder(mergedVectorValues, maxOrd); + builder.setInfoStream(infoStream); + return builder.build(maxOrd); + } + + private static boolean hasDeletes(Bits liveDocs) { + if (liveDocs == null) { + return false; + } + + for (int i = 0; i < liveDocs.length(); i++) { + if (liveDocs.get(i) == false) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/InitializedHnswGraphBuilder.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/InitializedHnswGraphBuilder.java new file mode 100644 index 0000000000000..68bb7421c83f4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/InitializedHnswGraphBuilder.java @@ -0,0 +1,56 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * This creates a graph builder that is initialized with the provided HnswGraph. This is useful for + * merging HnswGraphs from multiple segments. 
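Graph ordinals are segment-local, so before an existing graph can seed the merged graph every node id has to be rewritten through the old-to-new ordinal map that getNewOrdMapping above derives from the per-segment doc maps; initGraph below then walks the copied graph with exactly that map. A minimal illustration of the remapping with plain arrays standing in for the reader and doc-map machinery (oldOrdToDocId and docIdToNewOrd are hypothetical names):

    import java.util.Arrays;

    class OrdinalRemapSketch {

        /**
         * oldOrdToDocId[oldOrd] = doc id, in the merged segment, of the vector that had ordinal oldOrd in the
         * source segment (i.e. the source doc id already pushed through the merge doc map).
         * docIdToNewOrd[docId]  = ordinal of that doc's vector in the merged segment, or -1 if it has none.
         * Returns oldOrd -> newOrd, the mapping used to rewrite the copied graph's node ids.
         */
        static int[] remap(int[] oldOrdToDocId, int[] docIdToNewOrd) {
            int[] oldToNew = new int[oldOrdToDocId.length];
            for (int oldOrd = 0; oldOrd < oldOrdToDocId.length; oldOrd++) {
                oldToNew[oldOrd] = docIdToNewOrd[oldOrdToDocId[oldOrd]];
            }
            return oldToNew;
        }

        public static void main(String[] args) {
            // A source segment with 3 vectors (ordinals 0..2) whose docs land on merged doc ids 4, 0 and 7.
            int[] oldOrdToDocId = { 4, 0, 7 };
            // The merged segment assigns vector ordinals in doc-id order: doc 0 -> 0, doc 4 -> 1, doc 7 -> 2.
            int[] docIdToNewOrd = { 0, -1, -1, -1, 1, -1, -1, 2 };
            System.out.println(Arrays.toString(remap(oldOrdToDocId, docIdToNewOrd))); // prints [1, 0, 2]
        }
    }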
+ */ +public final class InitializedHnswGraphBuilder { + + public static OnHeapHnswGraph initGraph(HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors) throws IOException { + OnHeapHnswGraph hnsw = new OnHeapHnswGraph(initializerGraph.maxConn(), totalNumberOfVectors); + for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) { + HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level); + while (it.hasNext()) { + int oldOrd = it.nextInt(); + int newOrd = newOrdMap[oldOrd]; + hnsw.addNode(level, newOrd); + hnsw.trySetNewEntryNode(newOrd, level); + NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd); + initializerGraph.seek(level, oldOrd); + for (int oldNeighbor = initializerGraph.nextNeighbor(); oldNeighbor != NO_MORE_DOCS; oldNeighbor = initializerGraph + .nextNeighbor()) { + int newNeighbor = newOrdMap[oldNeighbor]; + // we will compute these scores later when we need to pop out the non-diverse nodes + newNeighbors.addOutOfOrder(newNeighbor, Float.NaN); + } + } + } + return hnsw; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/MergingHnswGraphBuilder.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/MergingHnswGraphBuilder.java new file mode 100644 index 0000000000000..d9e12b5323581 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/MergingHnswGraphBuilder.java @@ -0,0 +1,177 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.internal.hppc.IntHashSet; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateGraphsUtils; + +import java.io.IOException; +import java.util.Set; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * A graph builder that is used during segments' merging. + * + *
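Note that initGraph above copies every neighbour with a Float.NaN placeholder score via addOutOfOrder: real scores are only computed later, when a neighbour list overflows and has to be sorted for diversity pruning, so edges that are never pruned are never scored. A small sketch of that lazy-scoring idea (the field names and the fixed capacity are illustrative only):

    import java.util.function.IntToDoubleFunction;

    class LazyScoredNeighboursSketch {

        // fixed capacity to keep the sketch short; the real neighbour list is bounded by the max connection count
        private final int[] nodes = new int[8];
        private final float[] scores = new float[8];
        private int size = 0;

        /** Copy a neighbour without scoring it yet; merging only needs the node ids at this point. */
        void addUnscored(int node) {
            nodes[size] = node;
            scores[size] = Float.NaN;
            size++;
        }

        /** Fill in the real scores only when they are actually needed, e.g. just before diversity pruning. */
        void materializeScores(IntToDoubleFunction scorer) {
            for (int i = 0; i < size; i++) {
                if (Float.isNaN(scores[i])) {
                    scores[i] = (float) scorer.applyAsDouble(nodes[i]);
                }
            }
        }
    }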

This builder uses a smart algorithm to merge multiple graphs into a single graph. The + * algorithm is based on the idea that if we know where we want to insert a node, we have a good + * idea of where we want to insert its neighbors. + * + *
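Concretely, only a small join set of nodes from the smaller graph is inserted with a full top-down search; every other node is inserted through a much cheaper search that starts from entry points derived from its already-inserted neighbours, which is what updateGraph further below does. A rough, self-contained sketch of how those entry points are gathered (plain arrays and sets stand in for the graphs and ordinal maps of this patch):

    import java.util.LinkedHashSet;
    import java.util.Set;

    class EntryPointSeedingSketch {

        /**
         * For a node u of the small graph that is not in the join set, collect candidate entry points in the
         * large (merged) graph: every neighbour v of u that is already present in the large graph contributes
         * its mapped ordinal plus that ordinal's current neighbours.
         *
         * smallGraphNeighbours[u] = u's neighbours in the small graph (small-graph ordinals)
         * ordMap[v]               = v's ordinal in the large graph
         * largeGraphNeighbours    = level-0 adjacency of the large graph, indexed by large-graph ordinal
         */
        static Set<Integer> entryPointsFor(int u, int[][] smallGraphNeighbours, int[] ordMap, int[][] largeGraphNeighbours, Set<Integer> joinSet) {
            Set<Integer> eps = new LinkedHashSet<>();
            for (int v : smallGraphNeighbours[u]) {
                // v is already in the large graph if it was in the join set or was inserted before u (v < u)
                if (v < u || joinSet.contains(v)) {
                    int mapped = ordMap[v];
                    eps.add(mapped);
                    for (int w : largeGraphNeighbours[mapped]) {
                        eps.add(w);
                    }
                }
            }
            return eps; // the builder then searches for u's neighbours starting from these nodes only
        }
    }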

The algorithm is based on the following steps: + * + *
We expect the size of join set `j` to be small, around 1/5 to 1/2 of the size of gS. For the + * rest of the nodes in gS, we expect savings by performing lighter searches in gL. + */ +public final class MergingHnswGraphBuilder extends HnswGraphBuilder { + private final HnswGraph[] graphs; + private final int[][] ordMaps; + private final BitSet initializedNodes; + + private MergingHnswGraphBuilder( + RandomVectorScorerSupplier scorerSupplier, + int beamWidth, + long seed, + OnHeapHnswGraph initializedGraph, + HnswGraph[] graphs, + int[][] ordMaps, + BitSet initializedNodes + ) throws IOException { + super(scorerSupplier, beamWidth, seed, initializedGraph); + this.graphs = graphs; + this.ordMaps = ordMaps; + this.initializedNodes = initializedNodes; + } + + /** + * Create a new HnswGraphBuilder that is initialized with the provided HnswGraph. + * + * @param scorerSupplier the scorer to use for vectors + * @param beamWidth the number of nodes to explore in the search + * @param seed the seed for the random number generator + * @param graphs the graphs to merge + * @param ordMaps the ordinal maps for the graphs + * @param totalNumberOfVectors the total number of vectors in the new graph, this should include + * all vectors expected to be added to the graph in the future + * @param initializedNodes the nodes will be initialized through the merging + * @return a new HnswGraphBuilder that is initialized with the provided HnswGraph + * @throws IOException when reading the graph fails + */ + public static MergingHnswGraphBuilder fromGraphs( + RandomVectorScorerSupplier scorerSupplier, + int beamWidth, + long seed, + HnswGraph[] graphs, + int[][] ordMaps, + int totalNumberOfVectors, + BitSet initializedNodes + ) throws IOException { + OnHeapHnswGraph graph = InitializedHnswGraphBuilder.initGraph(graphs[0], ordMaps[0], totalNumberOfVectors); + return new MergingHnswGraphBuilder(scorerSupplier, beamWidth, seed, graph, graphs, ordMaps, initializedNodes); + } + + @Override + public OnHeapHnswGraph build(int maxOrd) throws IOException { + if (frozen) { + throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + String graphSizes = ""; + for (HnswGraph g : graphs) { + graphSizes += g.size() + " "; + } + infoStream.message( + HNSW_COMPONENT, + "build graph from merging " + graphs.length + " graphs of " + maxOrd + " vectors, graph sizes:" + graphSizes + ); + } + for (int i = 1; i < graphs.length; i++) { + updateGraph(graphs[i], ordMaps[i]); + } + + // TODO: optimize to iterate only over unset bits in initializedNodes + if (initializedNodes != null) { + for (int node = 0; node < maxOrd; node++) { + if (initializedNodes.get(node) == false) { + addGraphNode(node); + } + } + } + + return getCompletedGraph(); + } + + /** Merge the smaller graph into the current larger graph. 
*/ + private void updateGraph(HnswGraph gS, int[] ordMapS) throws IOException { + int size = gS.size(); + Set j = UpdateGraphsUtils.computeJoinSet(gS); + + // for nodes that in the join set, add them directly to the graph + for (int node : j) { + addGraphNode(ordMapS[node]); + } + + // for each node outside of j set: + // form the entry points set for the node + // by joining the node's neighbours in gS with + // the node's neighbours' neighbours in gL + for (int u = 0; u < size; u++) { + if (j.contains(u)) { + continue; + } + IntHashSet eps = new IntHashSet(); + gS.seek(0, u); + for (int v = gS.nextNeighbor(); v != NO_MORE_DOCS; v = gS.nextNeighbor()) { + // if u's neighbour v is in the join set, or already added to gL (v < u), + // then we add v's neighbours from gL to the candidate list + if (v < u || j.contains(v)) { + int newv = ordMapS[v]; + eps.add(newv); + + hnsw.seek(0, newv); + int friendOrd; + while ((friendOrd = hnsw.nextNeighbor()) != NO_MORE_DOCS) { + eps.add(friendOrd); + } + } + } + addGraphNodeWithEps(ordMapS[u], eps); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/NeighborArray.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/NeighborArray.java new file mode 100644 index 0000000000000..3f642e257dd39 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/NeighborArray.java @@ -0,0 +1,303 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.index.codec.vectors.es910.internal.hppc.MaxSizedFloatArrayList; +import org.elasticsearch.index.codec.vectors.es910.internal.hppc.MaxSizedIntArrayList; + +import java.io.IOException; +import java.util.Arrays; + +/** + * NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair + * of growable arrays. 
Nodes are arranged in the sorted order of their scores in descending order + * (if scoresDescOrder is true), or in the ascending order of their scores (if scoresDescOrder is + * false) + */ +public class NeighborArray implements Accountable { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(NeighborArray.class); + + private final boolean scoresDescOrder; + private int size; + private final int maxSize; + private final MaxSizedFloatArrayList scores; + private final MaxSizedIntArrayList nodes; + private int sortedNodeSize; + + public NeighborArray(int maxSize, boolean descOrder) { + this.maxSize = maxSize; + nodes = new MaxSizedIntArrayList(maxSize, maxSize / 8); + scores = new MaxSizedFloatArrayList(maxSize, maxSize / 8); + this.scoresDescOrder = descOrder; + } + + /** + * Add a new node to the NeighborArray. The new node must be worse than all previously stored + * nodes. This cannot be called after {@link #addOutOfOrder(int, float)} + */ + public void addInOrder(int newNode, float newScore) { + assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder"; + if (size == maxSize) { + throw new IllegalStateException("No growth is allowed"); + } + if (size > 0) { + float previousScore = scores.get(size - 1); + assert ((scoresDescOrder && (previousScore >= newScore)) || (scoresDescOrder == false && (previousScore <= newScore))) + : "Nodes are added in the incorrect order! Comparing " + newScore + " to " + Arrays.toString(scores.toArray()); + } + nodes.add(newNode); + scores.add(newScore); + ++size; + ++sortedNodeSize; + } + + /** Add node and newScore but do not insert as sorted */ + public void addOutOfOrder(int newNode, float newScore) { + if (size == maxSize) { + throw new IllegalStateException("No growth is allowed"); + } + + nodes.add(newNode); + scores.add(newScore); + size++; + } + + /** + * In addition to {@link #addOutOfOrder(int, float)}, this function will also remove the + * least-diverse node if the node array is full after insertion + * + *
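The pruning criterion is the usual HNSW diversity heuristic: a neighbour is evicted when it is at least as similar to one of the better-scoring neighbours as it is to the node that owns the list. A small, self-contained sketch of that rule over plain float vectors (the dot-product similarity and the descending-by-similarity ordering are assumptions for illustration); findWorstNonDiverse further below applies the same rule with extra bookkeeping for entries whose scores were filled in lazily:

    class DiversityPruneSketch {

        static float dot(float[] a, float[] b) {
            float s = 0;
            for (int i = 0; i < a.length; i++) {
                s += a[i] * b[i];
            }
            return s;
        }

        /**
         * Returns the index of a neighbour that violates the diversity rule, scanning from the most distant
         * neighbour, or the last index if every neighbour is diverse. neighbours must be non-empty and sorted
         * by descending similarity to owner, mirroring NeighborArray's ordering.
         */
        static int findWorstNonDiverse(float[] owner, float[][] neighbours) {
            for (int i = neighbours.length - 1; i > 0; i--) {
                float similarityToOwner = dot(owner, neighbours[i]);
                for (int j = 0; j < i; j++) {
                    // neighbour i is redundant: a closer neighbour j covers it better than the owner does
                    if (dot(neighbours[i], neighbours[j]) >= similarityToOwner) {
                        return i;
                    }
                }
            }
            return neighbours.length - 1; // everything is diverse; evict the most distant neighbour
        }
    }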

In multi-threading environment, this method need to be locked as it will be called by + * multiple threads while other add method is only supposed to be called by one thread. + * + * @param nodeId node Id of the owner of this NeighbourArray + */ + public void addAndEnsureDiversity(int newNode, float newScore, int nodeId, UpdateableRandomVectorScorer scorer) throws IOException { + addOutOfOrder(newNode, newScore); + if (size < maxSize) { + return; + } + // we're oversize, need to do diversity check and pop out the least diverse neighbour + scorer.setScoringOrdinal(nodeId); + removeIndex(findWorstNonDiverse(scorer)); + assert size == maxSize - 1; + } + + /** + * Sort the array according to scores, and return the sorted indexes of previous unsorted nodes + * (unchecked nodes) + * + * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is + * already fully sorted + */ + int[] sort(RandomVectorScorer scorer) throws IOException { + if (size == sortedNodeSize) { + // all nodes checked and sorted + return null; + } + assert sortedNodeSize < size; + int[] uncheckedIndexes = new int[size - sortedNodeSize]; + int count = 0; + while (sortedNodeSize != size) { + // TODO: Instead of do an array copy on every insertion, I think we can do better here: + // Remember the insertion point of each unsorted node and insert them altogether + // We can save several array copy by doing that + uncheckedIndexes[count] = insertSortedInternal(scorer); // sortedNodeSize is increased inside + for (int i = 0; i < count; i++) { + if (uncheckedIndexes[i] >= uncheckedIndexes[count]) { + // the previous inserted nodes has been shifted + uncheckedIndexes[i]++; + } + } + count++; + } + Arrays.sort(uncheckedIndexes); + return uncheckedIndexes; + } + + /** insert the first unsorted node into its sorted position */ + private int insertSortedInternal(RandomVectorScorer scorer) throws IOException { + assert sortedNodeSize < size : "Call this method only when there's unsorted node"; + int tmpNode = nodes.get(sortedNodeSize); + float tmpScore = scores.get(sortedNodeSize); + + if (Float.isNaN(tmpScore)) { + tmpScore = scorer.score(tmpNode); + } + + int insertionPoint = scoresDescOrder + ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize) + : ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize); + System.arraycopy(nodes.buffer, insertionPoint, nodes.buffer, insertionPoint + 1, sortedNodeSize - insertionPoint); + System.arraycopy(scores.buffer, insertionPoint, scores.buffer, insertionPoint + 1, sortedNodeSize - insertionPoint); + nodes.buffer[insertionPoint] = tmpNode; + scores.buffer[insertionPoint] = tmpScore; + ++sortedNodeSize; + return insertionPoint; + } + + /** This method is for test only. 
*/ + void insertSorted(int newNode, float newScore) throws IOException { + addOutOfOrder(newNode, newScore); + insertSortedInternal(null); + } + + public int size() { + return size; + } + + /** + * Direct access to the internal list of node ids; provided for efficient writing of the graph + */ + public int[] nodes() { + return nodes.buffer; + } + + /** + * Get the score at the given index + * + * @param i index of the score to get + * @return the score at the given index + */ + public float getScores(int i) { + return scores.get(i); + } + + public void clear() { + size = 0; + sortedNodeSize = 0; + nodes.clear(); + scores.clear(); + } + + void removeLast() { + nodes.removeLast(); + scores.removeLast(); + size--; + sortedNodeSize = Math.min(sortedNodeSize, size); + } + + void removeIndex(int idx) { + if (idx == size - 1) { + removeLast(); + return; + } + nodes.removeAt(idx); + scores.removeAt(idx); + if (idx < sortedNodeSize) { + sortedNodeSize--; + } + size--; + } + + @Override + public String toString() { + return "NeighborArray[" + size + "]"; + } + + private int ascSortFindRightMostInsertionPoint(float newScore, int bound) { + int insertionPoint = Arrays.binarySearch(scores.buffer, 0, bound, newScore); + if (insertionPoint >= 0) { + // find the right most position with the same score + while ((insertionPoint < bound - 1) && (scores.get(insertionPoint + 1) == scores.get(insertionPoint))) { + insertionPoint++; + } + insertionPoint++; + } else { + insertionPoint = -insertionPoint - 1; + } + return insertionPoint; + } + + private int descSortFindRightMostInsertionPoint(float newScore, int bound) { + int start = 0; + int end = bound - 1; + while (start <= end) { + int mid = (start + end) / 2; + if (scores.get(mid) < newScore) end = mid - 1; + else start = mid + 1; + } + return start; + } + + /** + * Find first non-diverse neighbour among the list of neighbors starting from the most distant + * neighbours + */ + private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer) throws IOException { + int[] uncheckedIndexes = sort(scorer); + assert uncheckedIndexes != null : "We will always have something unchecked"; + int uncheckedCursor = uncheckedIndexes.length - 1; + for (int i = size - 1; i > 0; i--) { + if (uncheckedCursor < 0) { + // no unchecked node left + break; + } + scorer.setScoringOrdinal(nodes.get(i)); + if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorer)) { + return i; + } + if (i == uncheckedIndexes[uncheckedCursor]) { + uncheckedCursor--; + } + } + return size - 1; + } + + private boolean isWorstNonDiverse(int candidateIndex, int[] uncheckedIndexes, int uncheckedCursor, RandomVectorScorer scorer) + throws IOException { + float minAcceptedSimilarity = scores.get(candidateIndex); + if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { + // the candidate itself is unchecked + for (int i = candidateIndex - 1; i >= 0; i--) { + float neighborSimilarity = scorer.score(nodes.get(i)); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } + } + } else { + // else we just need to make sure candidate does not violate diversity with the (newly + // inserted) unchecked nodes + assert candidateIndex > uncheckedIndexes[uncheckedCursor]; + for (int i = uncheckedCursor; i >= 0; i--) { + float neighborSimilarity = scorer.score(nodes.get(uncheckedIndexes[i])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity 
>= minAcceptedSimilarity) { + return true; + } + } + } + return false; + } + + public int maxSize() { + return maxSize; + } + + @Override + public long ramBytesUsed() { + return BASE_RAM_BYTES_USED + nodes.ramBytesUsed() + scores.ramBytesUsed(); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/OnHeapHnswGraph.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/OnHeapHnswGraph.java new file mode 100644 index 0000000000000..68b00f26924bc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/hnsw/OnHeapHnswGraph.java @@ -0,0 +1,339 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.hnsw; + +import org.apache.lucene.internal.hppc.IntArrayList; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * An {@link org.apache.lucene.util.hnsw.HnswGraph} where all nodes and connections are held in memory. This class is used to + * construct the HNSW graph before it's written to the index. + */ +public final class OnHeapHnswGraph extends HnswGraph implements Accountable { + + private static final int INIT_SIZE = 128; + + private final AtomicReference entryNode; + + // the internal graph representation where the first dimension is node id and second dimension is + // level + // e.g. 
graph[1][2] is all the neighbours of node 1 at level 2 + private NeighborArray[][] graph; + // essentially another 2d map which the first dimension is level and second dimension is node id, + // this is only + // generated on demand when there's someone calling getNodeOnLevel on a non-zero level + private IntArrayList[] levelToNodes; + private int lastFreezeSize; // remember the size we are at last time to freeze the graph and generate + // levelToNodes + private final AtomicInteger size = new AtomicInteger(0); // graph size, which is number of nodes in level 0 + private final AtomicInteger nonZeroLevelSize = new AtomicInteger(0); // total number of NeighborArrays created that is not on level 0, + // for now it + // is only used to account memory usage + private final AtomicInteger maxNodeId = new AtomicInteger(-1); + private final int nsize; // neighbour array size at non-zero level + private final int nsize0; // neighbour array size at zero level + private final boolean noGrowth; // if an initial size is passed in, we don't expect the graph to grow itself + + // KnnGraphValues iterator members + private int upto; + private NeighborArray cur; + + private volatile long graphRamBytesUsed; + + /** + * ctor + * + * @param numNodes number of nodes that will be added to this graph, passing in -1 means unbounded + * while passing in a non-negative value will lock the whole graph and disable the graph from + * growing itself (you cannot add a node with id >= numNodes) + */ + OnHeapHnswGraph(int M, int numNodes) { + this.entryNode = new AtomicReference<>(new EntryNode(-1, 1)); + // Neighbours' size on upper levels (nsize) and level 0 (nsize0) + // We allocate extra space for neighbours, but then prune them to keep allowed maximum + this.nsize = M + 1; + this.nsize0 = (M * 2 + 1); + noGrowth = numNodes != -1; + if (noGrowth == false) { + numNodes = INIT_SIZE; + } + this.graph = new NeighborArray[numNodes][]; + } + + /** + * Returns the {@link org.apache.lucene.util.hnsw.NeighborArray} connected to the given node. + * + * @param level level of the graph + * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. + */ + public NeighborArray getNeighbors(int level, int node) { + assert node < graph.length; + assert level < graph[node].length + : "level=" + level + ", node " + node + " has only " + graph[node].length + " levels for graph " + this; + assert graph[node][level] != null : "node=" + node + ", level=" + level; + return graph[node][level]; + } + + @Override + public int size() { + return size.get(); + } + + /** + * When we initialize from another graph, the max node id is different from {@link #size()}, + * because we will add nodes out of order, such that we need two method for each + * + * @return max node id (inclusive) + */ + @Override + public int maxNodeId() { + if (noGrowth) { + // we know the eventual graph size and the graph can possibly + // being concurrently modified + return graph.length - 1; + } else { + // The graph cannot be concurrently modified (and searched) if + // we don't know the size beforehand, so it's safe to return the + // actual maxNodeId + return maxNodeId.get(); + } + } + + /** + * Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes + * preceded by the node inserted out of order are eventually added. + * + *

NOTE: You must add a node starting from the node's top level + * + * @param level level to add a node on + * @param node the node to add, represented as an ordinal on the level 0. + */ + public void addNode(int level, int node) { + + if (node >= graph.length) { + if (noGrowth) { + throw new IllegalStateException("The graph does not expect to grow when an initial size is given"); + } + graph = ArrayUtil.grow(graph, node + 1); + } + + assert graph[node] == null || graph[node].length > level : "node must be inserted from the top level"; + if (graph[node] == null) { + graph[node] = new NeighborArray[level + 1]; // assumption: we always call this function from top level + size.incrementAndGet(); + } + if (level == 0) { + graph[node][level] = new NeighborArray(nsize0, true); + } else { + graph[node][level] = new NeighborArray(nsize, true); + nonZeroLevelSize.incrementAndGet(); + } + maxNodeId.accumulateAndGet(node, Math::max); + // update graphRamBytesUsed every 1000 nodes + if (level == 0 && node % 1000 == 0) { + updateGraphRamBytesUsed(); + } + } + + /** Finish building the graph. */ + public void finishBuild() { + updateGraphRamBytesUsed(); + } + + @Override + public void seek(int level, int targetNode) { + cur = getNeighbors(level, targetNode); + upto = -1; + } + + @Override + public int neighborCount() { + return cur.size(); + } + + @Override + public int nextNeighbor() { + if (++upto < cur.size()) { + return cur.nodes()[upto]; + } + return NO_MORE_DOCS; + } + + /** + * Returns the current number of levels in the graph + * + * @return the current number of levels in the graph + */ + @Override + public int numLevels() { + return entryNode.get().level + 1; + } + + @Override + public int maxConn() { + return nsize - 1; + } + + /** + * Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th + * level + * + * @return the graph's current entry node on the top level + */ + @Override + public int entryNode() { + return entryNode.get().node; + } + + /** + * Try to set the entry node if the graph does not have one + * + * @return True if the entry node is set to the provided node. False if the entry node already + * exists + */ + public boolean trySetNewEntryNode(int node, int level) { + EntryNode current = entryNode.get(); + if (current.node == -1) { + return entryNode.compareAndSet(current, new EntryNode(node, level)); + } + return false; + } + + /** + * Try to promote the provided node to the entry node + * + * @param level should be larger than expectedOldLevel + * @param expectOldLevel is the old entry node level the caller expect to be, the actual graph + * level can be different due to concurrent modification + * @return True if the entry node is set to the provided node. False if expectOldLevel is not the + * same as the current entry node level. Even if the provided node's level is still higher + * than the current entry node level, the new entry node will not be set and false will be + * returned. 
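Because several merge threads may add nodes concurrently, the entry point is published with a compare-and-set instead of a lock: promotion only succeeds if the entry level is still the one the caller observed. A minimal stand-alone sketch of that pattern (EntryPoint and tryPromote are illustrative stand-ins for the private EntryNode record and the method below):

    import java.util.concurrent.atomic.AtomicReference;

    class EntryPointPromotionSketch {

        record EntryPoint(int node, int level) {}

        private final AtomicReference<EntryPoint> entry = new AtomicReference<>(new EntryPoint(-1, 1));

        /** Promote node to entry point only if the entry level is still the level the caller saw. */
        boolean tryPromote(int node, int level, int expectedOldLevel) {
            EntryPoint current = entry.get();
            if (current.level() != expectedOldLevel) {
                return false; // someone else already changed the entry point; the caller must re-read and retry
            }
            return entry.compareAndSet(current, new EntryPoint(node, level));
        }
    }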
+ */ + public boolean tryPromoteNewEntryNode(int node, int level, int expectOldLevel) { + assert level > expectOldLevel; + EntryNode currentEntry = entryNode.get(); + if (currentEntry.level == expectOldLevel) { + return entryNode.compareAndSet(currentEntry, new EntryNode(node, level)); + } + return false; + } + + /** + * WARN: calling this method will essentially iterate through all nodes at level 0 (even if you're + * not getting node at level 0), we have built some caching mechanism such that if graph is not + * changed only the first non-zero level call will pay the cost. So it is highly NOT recommended + * to call this method while the graph is still building. + * + *

NOTE: calling this method while the graph is still building is prohibited + */ + @Override + public NodesIterator getNodesOnLevel(int level) { + if (size() != maxNodeId() + 1) { + throw new IllegalStateException("graph build not complete, size=" + size() + " maxNodeId=" + maxNodeId()); + } + if (level == 0) { + return new ArrayNodesIterator(size()); + } else { + generateLevelToNodes(); + return new CollectionNodesIterator(levelToNodes[level]); + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + private void generateLevelToNodes() { + if (lastFreezeSize == size()) { + return; + } + int maxLevels = numLevels(); + levelToNodes = new IntArrayList[maxLevels]; + for (int i = 1; i < maxLevels; i++) { + levelToNodes[i] = new IntArrayList(); + } + int nonNullNode = 0; + for (int node = 0; node < graph.length; node++) { + // when we init from another graph, we could have holes where some slot is null + if (graph[node] == null) { + continue; + } + nonNullNode++; + for (int i = 1; i < graph[node].length; i++) { + levelToNodes[i].add(node); + } + if (nonNullNode == size()) { + break; + } + } + lastFreezeSize = size(); + } + + /** Update the estimated ram bytes used for the neighbor array. */ + public void updateGraphRamBytesUsed() { + long currentRamBytesUsedEstimate = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + for (int node = 0; node < graph.length; node++) { + if (graph[node] == null) { + continue; + } + + for (int i = 0; i < graph[node].length; i++) { + if (graph[node][i] == null) { + continue; + } + currentRamBytesUsedEstimate += graph[node][i].ramBytesUsed(); + } + + currentRamBytesUsedEstimate += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + } + graphRamBytesUsed = currentRamBytesUsedEstimate; + } + + @Override + public long ramBytesUsed() { + long total = graphRamBytesUsed; // all NeighborArray + total += 4 * Integer.BYTES; // all int fields + total += 1; // field: noGrowth + total += RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + 2 * Integer.BYTES; // field: entryNode + total += 3L * (Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER); // 3 AtomicInteger + total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur + total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes + if (levelToNodes != null) { + total += (long) (numLevels() - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0 + total += (long) nonZeroLevelSize.get() * (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + Integer.BYTES); + } + return total; + } + + @Override + public String toString() { + return "OnHeapHnswGraph(size=" + size() + ", numLevels=" + numLevels() + ", entryNode=" + entryNode() + ")"; + } + + private record EntryNode(int node, int level) {} +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/internal/hppc/MaxSizedFloatArrayList.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/internal/hppc/MaxSizedFloatArrayList.java new file mode 100644 index 0000000000000..9abad12dc1676 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/internal/hppc/MaxSizedFloatArrayList.java @@ -0,0 +1,89 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.internal.hppc; + +import org.apache.lucene.internal.hppc.BitMixer; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.index.codec.vectors.es910.util.ArrayUtil; + +/** + * An array-backed list of {@code float} with a maximum size limit. + */ +public class MaxSizedFloatArrayList extends FloatArrayList { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MaxSizedFloatArrayList.class); + + final int maxSize; + + /** + * New instance with sane defaults. + * + * @param maxSize The maximum size this list can grow to + * @param expectedElements The expected number of elements guaranteed not to cause buffer + * expansion (inclusive). + */ + public MaxSizedFloatArrayList(int maxSize, int expectedElements) { + super(expectedElements); + assert expectedElements <= maxSize : "expectedElements (" + expectedElements + ") must be <= maxSize (" + maxSize + ")"; + this.maxSize = maxSize; + } + + @Override + protected void ensureBufferSpace(int expectedAdditions) { + if (elementsCount + expectedAdditions > maxSize) { + throw new IllegalStateException("Cannot grow beyond maxSize: " + maxSize); + } + if (elementsCount + expectedAdditions > buffer.length) { + this.buffer = ArrayUtil.growInRange(buffer, elementsCount + expectedAdditions, maxSize); + } + } + + @Override + public int hashCode() { + int h = 1, max = elementsCount; + h = 31 * h + maxSize; + for (int i = 0; i < max; i++) { + h = 31 * h + BitMixer.mix(this.buffer[i]); + } + return h; + } + + /** + * Returns true only if the other object is an instance of the same class and with + * the same elements and maxSize. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + MaxSizedFloatArrayList other = (MaxSizedFloatArrayList) obj; + return maxSize == other.maxSize && super.equals(obj); + } + + @Override + public long ramBytesUsed() { + return BASE_RAM_BYTES_USED + RamUsageEstimator.sizeOf(buffer); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/internal/hppc/MaxSizedIntArrayList.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/internal/hppc/MaxSizedIntArrayList.java new file mode 100644 index 0000000000000..2eb150a4d23d7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/internal/hppc/MaxSizedIntArrayList.java @@ -0,0 +1,90 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors.es910.internal.hppc; + +import org.apache.lucene.internal.hppc.BitMixer; +import org.apache.lucene.internal.hppc.IntArrayList; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.RamUsageEstimator; + +/** + * An array-backed list of {@code int} with a maximum size limit. + */ +public class MaxSizedIntArrayList extends IntArrayList { + + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MaxSizedIntArrayList.class); + + final int maxSize; + + /** + * New instance with sane defaults. + * + * @param maxSize The maximum size this list can grow to + * @param expectedElements The expected number of elements guaranteed not to cause buffer + * expansion (inclusive). + */ + public MaxSizedIntArrayList(int maxSize, int expectedElements) { + super(expectedElements); + assert expectedElements <= maxSize : "expectedElements (" + expectedElements + ") must be <= maxSize (" + maxSize + ")"; + this.maxSize = maxSize; + } + + @Override + protected void ensureBufferSpace(int expectedAdditions) { + if (elementsCount + expectedAdditions > maxSize) { + throw new IllegalStateException("Cannot grow beyond maxSize: " + maxSize); + } + if (elementsCount + expectedAdditions > buffer.length) { + this.buffer = ArrayUtil.growInRange(buffer, elementsCount + expectedAdditions, maxSize); + } + } + + @Override + public int hashCode() { + int h = 1, max = elementsCount; + h = 31 * h + maxSize; + for (int i = 0; i < max; i++) { + h = 31 * h + BitMixer.mix(this.buffer[i]); + } + return h; + } + + /** + * Returns true only if the other object is an instance of the same class and with + * the same elements and maxSize. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + MaxSizedIntArrayList other = (MaxSizedIntArrayList) obj; + return maxSize == other.maxSize && super.equals(obj); + } + + @Override + public long ramBytesUsed() { + return BASE_RAM_BYTES_USED + RamUsageEstimator.sizeOf(buffer); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/util/ArrayUtil.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/util/ArrayUtil.java new file mode 100644 index 0000000000000..621584e125991 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es910/util/ArrayUtil.java @@ -0,0 +1,44 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2025 Elasticsearch B.V.
+ */
+
+package org.elasticsearch.index.codec.vectors.es910.util;
+
+import static org.apache.lucene.util.ArrayUtil.growExact;
+import static org.apache.lucene.util.ArrayUtil.oversize;
+
+public class ArrayUtil {
+
+    public static float[] growInRange(float[] array, int minLength, int maxLength) {
+        assert minLength >= 0 : "minLength must be non-negative (got " + minLength + "): likely integer overflow?";
+
+        if (minLength > maxLength) {
+            throw new IllegalArgumentException(
+                "requested minimum array length " + minLength + " is larger than requested maximum array length " + maxLength
+            );
+        }
+
+        if (array.length >= minLength) {
+            return array;
+        }
+
+        int potentialLength = oversize(minLength, Float.BYTES);
+        return growExact(array, Math.min(maxLength, potentialLength));
+    }
+}
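Not part of the patch — a sketch of the contract of the float[] growInRange added above (Lucene's own org.apache.lucene.util.ArrayUtil supplies the int[] variant used by MaxSizedIntArrayList): if the array already satisfies minLength it is returned as-is, otherwise it grows by the usual oversize heuristic but never past maxLength. The class name below is made up; the exact grown length depends on ArrayUtil.oversize.

import org.elasticsearch.index.codec.vectors.es910.util.ArrayUtil;

class GrowInRangeSketch {
    public static void main(String[] args) {
        float[] buffer = new float[8];

        // Already at least minLength: the same array is returned, no copy.
        assert ArrayUtil.growInRange(buffer, 4, 16) == buffer;

        // Too small: oversized as usual, but the result is clamped at maxLength.
        float[] grown = ArrayUtil.growInRange(buffer, 12, 16);
        assert grown.length >= 12 && grown.length <= 16;

        // minLength > maxLength is a caller bug and throws IllegalArgumentException.
    }
}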
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index 40c69ec6e7fd4..9705ced906489 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -52,6 +52,7 @@
 import org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
 import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
@@ -2003,7 +2004,7 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
             if (elementType == ElementType.BIT) {
                 return new ES815HnswBitVectorsFormat(m, efConstruction);
             }
-            return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
+            return new ES910HnswVectorsFormat(m, efConstruction);
         }
 
         @Override
diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
index 14e68029abc3b..a7f5daf2c5e1e 100644
--- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
+++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -7,4 +7,5 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat
 org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat
 org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat
 org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat
+org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat
 org.elasticsearch.index.codec.vectors.IVFVectorsFormat
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsFormatTests.java
new file mode 100644
index 0000000000000..bec849721a698
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es910/ES910HnswVectorsFormatTests.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors.es910;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+import org.elasticsearch.common.logging.LogConfigurator;
+
+public class ES910HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
+
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES910HnswVectorsFormat());
+
+    @Override
+    protected Codec getCodec() {
+        return codec;
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
index 7acb97ffc3e51..05d29b9bca1fd 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
@@ -2171,7 +2171,7 @@ public void testKnnVectorsFormat() throws IOException {
             assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class));
             knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field");
         }
-        String expectedString = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn="
+        String expectedString = "ES910HnswReducedHeapVectorsFormat(name=ES910HnswReducedHeapVectorsFormat, maxConn="
             + (setM ? m : DEFAULT_MAX_CONN)
             + ", beamWidth="
             + (setEfConstruction ? efConstruction : DEFAULT_BEAM_WIDTH)
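Not part of the patch — a rough sketch of the wiring the new test relies on: TestUtil.alwaysKnnVectorsFormat wraps a codec so every vector field uses the supplied format, which is how ES910HnswVectorsFormat gets exercised by BaseKnnVectorsFormatTestCase. The same wiring done by hand with a plain IndexWriter; the class, directory, field name, and vector values below are made up for illustration.

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.index.codec.vectors.es910.ES910HnswVectorsFormat;

class Es910IndexingSketch {
    public static void main(String[] args) throws Exception {
        IndexWriterConfig iwc = new IndexWriterConfig(new StandardAnalyzer());
        // Force every field to use the ES910 HNSW format, as the test codec does.
        iwc.setCodec(TestUtil.alwaysKnnVectorsFormat(new ES910HnswVectorsFormat()));
        try (IndexWriter writer = new IndexWriter(new ByteBuffersDirectory(), iwc)) {
            Document doc = new Document();
            doc.add(new KnnFloatVectorField("vector", new float[] { 0.1f, 0.2f, 0.3f }));
            writer.addDocument(doc); // flush/merge paths then go through the ES910 writer
        }
    }
}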