diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index d2f1c528de9d..ac074d8c60f1 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -222,6 +222,8 @@ Optimizations
 * GITHUB#15303: Speed up NumericUtils#{add,subtract} by operating on integers
   instead of bytes. (Kaival Parikh)
 
+* GITHUB#15003: Avoid reconstructing HNSW graph during merging (Pulkit Gupta)
+
 Bug Fixes
 ---------------------
 * GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
index f6b88424a76e..4b6244c18522 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
@@ -19,7 +19,7 @@
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
-import java.util.Comparator;
+import java.util.Arrays;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
 import org.apache.lucene.index.FieldInfo;
@@ -57,14 +57,12 @@ protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxO
     OnHeapHnswGraph graph;
     BitSet initializedNodes = null;
 
-    if (graphReaders.size() == 0) {
+    if (largestGraphReader == null) {
       graph = new OnHeapHnswGraph(M, maxOrd);
     } else {
-      graphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed());
-      GraphReader initGraphReader = graphReaders.get(0);
-      KnnVectorsReader initReader = initGraphReader.reader();
-      MergeState.DocMap initDocMap = initGraphReader.initDocMap();
-      int initGraphSize = initGraphReader.graphSize();
+      KnnVectorsReader initReader = largestGraphReader.reader();
+      MergeState.DocMap initDocMap = largestGraphReader.initDocMap();
+      int initGraphSize = largestGraphReader.graphSize();
 
       HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
       if (initializerGraph.size() == 0) {
@@ -79,7 +77,9 @@ protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxO
               initGraphSize,
               mergedVectorValues,
               initializedNodes);
-      graph = InitializedHnswGraphBuilder.initGraph(initializerGraph, oldToNewOrdinalMap, maxOrd);
+      graph =
+          InitializedHnswGraphBuilder.initGraph(
+              initializerGraph, oldToNewOrdinalMap, maxOrd, beamWidth, scorerSupplier);
       }
     }
     return new HnswConcurrentMergeBuilder(
@@ -117,6 +117,9 @@ private static int[] getNewOrdMapping(
         docId != NO_MORE_DOCS;
         docId = initializerIterator.nextDoc()) {
       int newId = initDocMap.get(docId);
+      if (newId == -1) {
+        continue;
+      }
       maxNewDocID = Math.max(newId, maxNewDocID);
       assert newIdToOldOrdinal.containsKey(newId) == false;
       newIdToOldOrdinal.put(newId, initializerIterator.index());
@@ -126,6 +129,7 @@
       return new int[0];
     }
     final int[] oldToNewOrdinalMap = new int[initGraphSize];
+    Arrays.fill(oldToNewOrdinalMap, -1);
     KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator();
     for (int newDocId = mergedVectorIterator.nextDoc();
         newDocId <= maxNewDocID;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index cd54443ab760..70f607d3b5d0 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -64,14 +64,14 @@ public class HnswGraphBuilder implements HnswBuilder {
   @SuppressWarnings("NonFinalStaticField")
   public static long randSeed = DEFAULT_RAND_SEED;
 
-  private final int M; // max number of connections on upper layers
+  protected final int M; // max number of connections on upper layers
   private final double ml;
   private final SplittableRandom random;
-  private final UpdateableRandomVectorScorer scorer;
-  private final HnswGraphSearcher graphSearcher;
+  protected final UpdateableRandomVectorScorer scorer;
+  protected final HnswGraphSearcher graphSearcher;
 
   private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search
-  private final GraphBuilderKnnCollector
+  protected final GraphBuilderKnnCollector
       beamCandidates; // for levels of graph where we add the node
   private final GraphBuilderKnnCollector beamCandidates0;
@@ -288,7 +288,7 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
     // then do connections from bottom up
     for (int i = 0; i < scratchPerLevel.length; i++) {
-      addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], scorer);
+      addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], scorer, false);
     }
     lowestUnsetLevel += scratchPerLevel.length;
     assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
@@ -344,17 +344,22 @@ private long printGraphBuildStatus(int node, long start, long t) {
     return now;
   }
 
-  private void addDiverseNeighbors(
-      int level, int node, NeighborArray candidates, UpdateableRandomVectorScorer scorer)
+  void addDiverseNeighbors(
+      int level,
+      int node,
+      NeighborArray candidates,
+      UpdateableRandomVectorScorer scorer,
+      boolean outOfOrderInsertion)
       throws IOException {
     /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
      * is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
      * since the node is new and has no prior neighbors).
      */
     NeighborArray neighbors = hnsw.getNeighbors(level, node);
-    assert neighbors.size() == 0; // new node
     int maxConnOnLevel = level == 0 ? M * 2 : M;
-    boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, scorer);
+    boolean[] mask =
+        selectAndLinkDiverse(
+            node, neighbors, candidates, maxConnOnLevel, scorer, outOfOrderInsertion);
 
     // Link the selected nodes to the new node, and the new node to the selected nodes (again
     // applying diversity heuristic)
@@ -386,10 +391,12 @@ private void addDiverseNeighbors(
    * are selected
    */
   private boolean[] selectAndLinkDiverse(
+      int node,
       NeighborArray neighbors,
       NeighborArray candidates,
       int maxConnOnLevel,
-      UpdateableRandomVectorScorer scorer)
+      UpdateableRandomVectorScorer scorer,
+      boolean outOfOrderInsertion)
       throws IOException {
     boolean[] mask = new boolean[candidates.size()];
     // Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
@@ -397,6 +404,9 @@ private boolean[] selectAndLinkDiverse(
       // compare each neighbor (in distance order) against the closer neighbors selected so far,
       // only adding it if it is closer to the target than to any of the other selected neighbors
       int cNode = candidates.nodes()[i];
+      if (node == cNode) {
+        continue;
+      }
       float cScore = candidates.getScores(i);
       assert cNode <= hnsw.maxNodeId();
       scorer.setScoringOrdinal(cNode);
@@ -404,13 +414,17 @@ private boolean[] selectAndLinkDiverse(
         mask[i] = true;
         // here we don't need to lock, because there's no incoming link so no others is able to
         // discover this node such that no others will modify this neighbor array as well
-        neighbors.addInOrder(cNode, cScore);
+        if (outOfOrderInsertion) {
+          neighbors.addOutOfOrder(cNode, cScore);
+        } else {
+          neighbors.addInOrder(cNode, cScore);
+        }
       }
     }
     return mask;
   }
 
-  private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
+  static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
     scratch.clear();
     int candidateCount = candidates.size();
     // extract all the Neighbors from the queue into an array; these will now be
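
The outOfOrderInsertion flag threaded through addDiverseNeighbors and selectAndLinkDiverse above exists because repair-time inserts land in a neighbor list that may already hold entries, so addInOrder's ascending-arrival precondition no longer holds. A minimal toy sketch of the two insertion modes; this is not Lucene's NeighborArray, and all names here are illustrative:

import java.util.Arrays;

/** Toy score-sorted neighbor list; illustrative only, not Lucene's NeighborArray. */
public class ToyNeighbors {
  private int[] nodes = new int[4];
  private float[] scores = new float[4];
  private int size = 0;
  private boolean sorted = true;

  /** Fast path used during normal building: entries arrive best-to-worst. */
  void addInOrder(int node, float score) {
    if (size > 0 && score > scores[size - 1]) {
      throw new IllegalStateException("insertion would break descending-score order");
    }
    append(node, score);
  }

  /** Repair path: the list may already hold arbitrary entries, so sorting is deferred. */
  void addOutOfOrder(int node, float score) {
    append(node, score);
    sorted = false;
  }

  private void append(int node, float score) {
    if (size == nodes.length) {
      nodes = Arrays.copyOf(nodes, size * 2);
      scores = Arrays.copyOf(scores, size * 2);
    }
    nodes[size] = node;
    scores[size] = score;
    size++;
  }

  public static void main(String[] args) {
    ToyNeighbors n = new ToyNeighbors();
    n.addInOrder(7, 0.9f);
    n.addInOrder(3, 0.8f); // fine: scores are non-increasing
    n.addOutOfOrder(5, 0.95f); // fine: sortedness is simply given up
    System.out.println("sorted=" + n.sorted); // sorted=false
  }
}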
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
index a29afeb615f2..dfedb66feda1 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
@@ -20,16 +20,16 @@
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Comparator;
 import java.util.List;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
-import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.FieldInfo;
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.KnnVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.internal.hppc.IntIntHashMap;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.FixedBitSet;
@@ -48,8 +48,20 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
   protected final int beamWidth;
 
   protected List<GraphReader> graphReaders = new ArrayList<>();
+  protected GraphReader largestGraphReader;
+  private int numReaders = 0;
 
+  /**
+   * The maximum acceptable deletion percentage for a graph to be considered as the base graph.
+   * Graphs with deletion percentages above this threshold are not used for initialization, as
+   * they may have degraded connectivity.
+   *
+   * <p>A value of 40 means that if more than 40% of the graph's original vectors have been
+   * deleted, the graph will not be selected as the base.
+   */
+  private final int DELETE_PCT_THRESHOLD = 40;
+
   /** Represents a vector reader that contains graph info. */
   protected record GraphReader(
       KnnVectorsReader reader, MergeState.DocMap initDocMap, int graphSize) {}
@@ -67,13 +79,13 @@ public IncrementalHnswGraphMerger(
 
   /**
    * Adds a reader to the graph merger if it meets the following criteria: 1. does not contain any
-   * deleted docs 2. is a HnswGraphProvider
+   * deleted vectors 2. is a HnswGraphProvider
    */
   @Override
   public IncrementalHnswGraphMerger addReader(
       KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
     numReaders++;
-    if (hasDeletes(liveDocs) || !(reader instanceof HnswGraphProvider)) {
+    if (!(reader instanceof HnswGraphProvider)) {
       return this;
     }
     HnswGraph graph = ((HnswGraphProvider) reader).getGraph(fieldInfo.name);
@@ -81,24 +93,29 @@ public IncrementalHnswGraphMerger addReader(
       return this;
     }
 
-    int candidateVectorCount = 0;
-    switch (fieldInfo.getVectorEncoding()) {
-      case BYTE -> {
-        ByteVectorValues byteVectorValues = reader.getByteVectorValues(fieldInfo.name);
-        if (byteVectorValues == null) {
-          return this;
-        }
-        candidateVectorCount = byteVectorValues.size();
-      }
-      case FLOAT32 -> {
-        FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldInfo.name);
-        if (vectorValues == null) {
-          return this;
-        }
-        candidateVectorCount = vectorValues.size();
-      }
+    KnnVectorValues knnVectorValues =
+        switch (fieldInfo.getVectorEncoding()) {
+          case BYTE -> reader.getByteVectorValues(fieldInfo.name);
+          case FLOAT32 -> reader.getFloatVectorValues(fieldInfo.name);
+        };
+
+    int candidateVectorCount = countLiveVectors(liveDocs, knnVectorValues);
+    int graphSize = graph.size();
+
+    GraphReader graphReader = new GraphReader(reader, docMap, graphSize);
+
+    int deletePct = ((graphSize - candidateVectorCount) * 100) / graphSize;
+
+    if (deletePct <= DELETE_PCT_THRESHOLD
+        && (largestGraphReader == null || candidateVectorCount > largestGraphReader.graphSize)) {
+      largestGraphReader = graphReader;
     }
-    graphReaders.add(new GraphReader(reader, docMap, candidateVectorCount));
+
+    // if graph has no deletes
+    if (candidateVectorCount == graphSize) {
+      graphReaders.add(graphReader);
+    }
+
     return this;
   }
@@ -112,11 +129,15 @@ public IncrementalHnswGraphMerger addReader(
    */
   protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
       throws IOException {
-    if (graphReaders.size() == 0) {
+    if (largestGraphReader == null) {
       return HnswGraphBuilder.create(
           scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, maxOrd);
     }
-    graphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed());
+    if (!graphReaders.contains(largestGraphReader)) {
+      graphReaders.addFirst(largestGraphReader);
+    } else {
+      graphReaders.sort(Comparator.comparingInt(GraphReader::graphSize).reversed());
+    }
 
     final BitSet initializedNodes =
         graphReaders.size() == numReaders ? null : new FixedBitSet(maxOrd);
@@ -163,6 +184,7 @@ protected final int[][] getNewOrdMapping(
         newDocIdToOldOrdinals[i].put(newDocId, vectorsIter.index());
       }
       oldToNewOrdinalMap[i] = new int[graphReaders.get(i).graphSize];
+      Arrays.fill(oldToNewOrdinalMap[i], -1);
     }
 
     KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator();
@@ -192,16 +214,21 @@ public OnHeapHnswGraph merge(
     return builder.build(maxOrd);
   }
 
-  private static boolean hasDeletes(Bits liveDocs) {
+  private static int countLiveVectors(Bits liveDocs, KnnVectorValues knnVectorValues)
+      throws IOException {
     if (liveDocs == null) {
-      return false;
+      return knnVectorValues.size();
     }
-    for (int i = 0; i < liveDocs.length(); i++) {
-      if (!liveDocs.get(i)) {
-        return true;
+    int count = 0;
+    DocIdSetIterator docIdSetIterator = knnVectorValues.iterator();
+    for (int doc = docIdSetIterator.nextDoc();
+        doc != NO_MORE_DOCS;
+        doc = docIdSetIterator.nextDoc()) {
+      if (liveDocs.get(doc)) {
+        count++;
       }
     }
-    return false;
+    return count;
   }
 }
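
The base-graph selection rule that addReader now applies can be shown in isolation. A standalone sketch; SegmentGraph and the sample numbers are hypothetical, and only the arithmetic mirrors the patch:

import java.util.List;

/** Sketch of the base-graph selection introduced in addReader; all data is made up. */
record SegmentGraph(String name, int graphSize, int liveVectors) {}

public class BaseGraphSelection {
  static final int DELETE_PCT_THRESHOLD = 40;

  static SegmentGraph selectBase(List<SegmentGraph> segments) {
    SegmentGraph best = null;
    for (SegmentGraph s : segments) {
      // integer percentage of the graph's original vectors that were deleted
      int deletePct = (s.graphSize() - s.liveVectors()) * 100 / s.graphSize();
      // note: mirrors the patch, which compares a candidate's live count
      // against the current best's graphSize
      if (deletePct <= DELETE_PCT_THRESHOLD
          && (best == null || s.liveVectors() > best.graphSize())) {
        best = s;
      }
    }
    return best; // null => build the merged graph from scratch
  }

  public static void main(String[] args) {
    var segments = List.of(
        new SegmentGraph("seg0", 1000, 550), // 45% deleted -> rejected
        new SegmentGraph("seg1", 800, 700),  // 12% deleted -> eligible
        new SegmentGraph("seg2", 600, 600)); // no deletes  -> eligible, but smaller
    System.out.println(selectBase(segments).name()); // seg1
  }
}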
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
index 7dff036ddde4..511474b2f702 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
@@ -20,29 +20,86 @@
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.SplittableRandom;
+import org.apache.lucene.internal.hppc.IntArrayList;
+import org.apache.lucene.internal.hppc.IntCursor;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BitSet;
 
 /**
  * This creates a graph builder that is initialized with the provided HnswGraph. This is useful for
  * merging HnswGraphs from multiple segments.
  *
+ * <p>The builder performs the following operations:
+ *
+ * <ul>
+ *   <li>Copies the source graph structure, remapping old ordinals to new ordinals
+ *   <li>Repairs disconnected nodes if the source graph had deletions
+ *   <li>Rebalances the level hierarchy if the source graph had deletions
+ * </ul>
+ *
+ * <p><b>Disconnected Node Detection:</b> A node is considered disconnected if it retains less
+ * than {@link #DISCONNECTED_NODE_FACTOR} of its original neighbor count from the source graph.
+ * This typically occurs when many of the node's neighbors were deleted documents that couldn't
+ * be remapped.
+ *
  * @lucene.experimental
  */
 public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
 
   /**
-   * Create a new HnswGraphBuilder that is initialized with the provided HnswGraph.
-   *
-   * @param scorerSupplier the scorer to use for vectors
-   * @param beamWidth the number of nodes to explore in the search
-   * @param seed the seed for the random number generator
-   * @param initializerGraph the graph to initialize the new graph builder
-   * @param newOrdMap a mapping from the old node ordinal to the new node ordinal
-   * @param initializedNodes a bitset of nodes that are already initialized in the initializerGraph
-   * @param totalNumberOfVectors the total number of vectors in the new graph, this should include
-   *     all vectors expected to be added to the graph in the future
-   * @return a new HnswGraphBuilder that is initialized with the provided HnswGraph
-   * @throws IOException when reading the graph fails
+   * Tracks which nodes have already been initialized from the source graph. These nodes will be
+   * skipped during subsequent {@link #addGraphNode(int)} calls to avoid duplicate processing.
+   */
+  private final BitSet initializedNodes;
+
+  /**
+   * Maps each level in the graph hierarchy to the list of node ordinals present at that level.
+   * Used during graph rebalancing to identify candidates for promotion to higher levels.
+   */
+  private IntArrayList[] levelToNodes;
+
+  /**
+   * The threshold factor for determining if a node is disconnected. A node is considered
+   * disconnected if its new neighbor count is less than {@code (old neighbor count *
+   * DISCONNECTED_NODE_FACTOR)}.
+   *
+   * <p>This helps identify nodes that have lost a significant portion of their neighbors
+   * (typically due to document deletions) and need additional connections to maintain graph
+   * connectivity and search performance.
+   */
+  private final double DISCONNECTED_NODE_FACTOR = 0.85;
+
+  // Tracks if the graph has deletes
+  private boolean hasDeletes = false;
+
+  /**
+   * Creates an initialized HNSW graph builder from an existing graph.
+   *
+   * <p>This factory method constructs a new graph builder, initializes it with the structure from
+   * the provided graph (applying ordinal remapping), and returns the builder ready for additional
+   * operations.
+   *
+   * @param scorerSupplier provides vector similarity scoring for graph operations
+   * @param beamWidth the search beam width for finding neighbors during graph construction
+   * @param seed random seed for level assignment and node promotion during rebalancing
+   * @param initializerGraph the source graph to copy structure from
+   * @param newOrdMap maps old ordinals in the initializer graph to new ordinals in the merged
+   *     graph; -1 indicates a deleted document that should be skipped
+   * @param initializedNodes bit set marking which nodes are already initialized (can be null if
+   *     not tracking)
+   * @param totalNumberOfVectors the total number of vectors in the merged graph (used for
+   *     pre-allocation)
+   * @return a new builder initialized with the provided graph structure
+   * @throws IOException if an I/O error occurs during graph initialization
    */
   public static InitializedHnswGraphBuilder fromGraph(
       RandomVectorScorerSupplier scorerSupplier,
@@ -53,49 +110,338 @@ public static InitializedHnswGraphBuilder fromGraph(
       BitSet initializedNodes,
       int totalNumberOfVectors)
       throws IOException {
-    return new InitializedHnswGraphBuilder(
-        scorerSupplier,
-        beamWidth,
-        seed,
-        initGraph(initializerGraph, newOrdMap, totalNumberOfVectors),
-        initializedNodes);
+
+    InitializedHnswGraphBuilder builder =
+        new InitializedHnswGraphBuilder(
+            scorerSupplier,
+            beamWidth,
+            seed,
+            new OnHeapHnswGraph(initializerGraph.maxConn(), totalNumberOfVectors),
+            initializedNodes);
+
+    builder.initializeFromGraph(initializerGraph, newOrdMap);
+    return builder;
   }
 
+  /**
+   * Convenience method to create a fully initialized on-heap HNSW graph without tracking
+   * initialized nodes. This is useful when you just need the resulting graph structure without
+   * planning to add additional nodes incrementally.
+   *
+   * @param initializerGraph the source graph to copy structure from
+   * @param newOrdMap maps old ordinals to new ordinals; -1 indicates deleted documents
+   * @param totalNumberOfVectors the total number of vectors in the merged graph
+   * @param beamWidth the search beam width for graph construction
+   * @param scorerSupplier provides vector similarity scoring
+   * @return a fully initialized on-heap HNSW graph
+   * @throws IOException if an I/O error occurs during graph initialization
+   */
   public static OnHeapHnswGraph initGraph(
-      HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors) throws IOException {
-    OnHeapHnswGraph hnsw = new OnHeapHnswGraph(initializerGraph.maxConn(), totalNumberOfVectors);
-    for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
+      HnswGraph initializerGraph,
+      int[] newOrdMap,
+      int totalNumberOfVectors,
+      int beamWidth,
+      RandomVectorScorerSupplier scorerSupplier)
+      throws IOException {
+
+    InitializedHnswGraphBuilder builder =
+        fromGraph(
+            scorerSupplier,
+            beamWidth,
+            randSeed,
+            initializerGraph,
+            newOrdMap,
+            null,
+            totalNumberOfVectors);
+    return builder.getGraph();
+  }
+
+  private InitializedHnswGraphBuilder(
+      RandomVectorScorerSupplier scorerSupplier,
+      int beamWidth,
+      long seed,
+      OnHeapHnswGraph initializedGraph,
+      BitSet initializedNodes)
+      throws IOException {
+    super(scorerSupplier, beamWidth, seed, initializedGraph);
+    this.initializedNodes = initializedNodes;
+  }
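
The newOrdMap contract that fromGraph and initGraph share (-1 marks a vector whose document was deleted) boils down to a dense renumbering of the survivors. A minimal standalone demo; the live[] array is invented:

import java.util.Arrays;

/** Demo of old-to-new ordinal remapping with -1 marking deleted vectors. */
public class OrdRemapDemo {
  public static void main(String[] args) {
    // Hypothetical segment with 6 vectors; vectors 1 and 4 belong to deleted docs
    boolean[] live = {true, false, true, true, false, true};

    int[] oldToNew = new int[live.length];
    Arrays.fill(oldToNew, -1); // -1 = "no longer present in the merged graph"

    int next = 0;
    for (int oldOrd = 0; oldOrd < live.length; oldOrd++) {
      if (live[oldOrd]) {
        oldToNew[oldOrd] = next++; // surviving vectors get dense new ordinals
      }
    }
    System.out.println(Arrays.toString(oldToNew)); // [0, -1, 1, 2, -1, 3]
  }
}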
 
+  /**
+   * Initializes the graph from the provided initializer graph through a three-phase process:
+   *
+   * <ol>
+   *   <li>Copy the graph structure with ordinal remapping, identifying disconnected nodes
+   *   <li>If deletions occurred, repair disconnected nodes by finding additional neighbors
+   *   <li>If deletions occurred, rebalance the graph hierarchy to maintain proper level
+   *       distribution
+   * </ol>
+   *
+   * @param initializerGraph the source graph to copy from
+   * @param newOrdMap ordinal mapping from old to new ordinals
+   * @throws IOException if an I/O error occurs during initialization
+   */
+  private void initializeFromGraph(HnswGraph initializerGraph, int[] newOrdMap)
+      throws IOException {
+    hasDeletes = false;
+    // Phase 1: Copy structure and identify nodes that lost too many neighbors
+    Map<Integer, List<Integer>> disconnectedNodesByLevel =
+        copyGraphStructure(initializerGraph, newOrdMap);
+
+    // Repair graph if it has deletes
+    if (hasDeletes) {
+      // Phase 2: Repair nodes with insufficient connections
+      repairDisconnectedNodes(disconnectedNodesByLevel, initializerGraph.numLevels());
+
+      // Phase 3: Rebalance graph to maintain proper level distribution
+      rebalanceGraph();
+    }
+  }
+
+  /**
+   * Copies the graph structure from the initializer graph, applying ordinal remapping and
+   * identifying nodes that have lost neighbors.
+   *
+   * <p>A node is considered disconnected if it retains less than {@link #DISCONNECTED_NODE_FACTOR}
+   * of its original neighbors. This happens when many neighbors were deleted documents that
+   * couldn't be remapped (indicated by -1 in newOrdMap).
+   *
+   * <p>Example: with DISCONNECTED_NODE_FACTOR = 0.85, if a node had 20 neighbors in the source
+   * graph but only 16 remain after remapping (16/20 = 0.8 &lt; 0.85), it is marked as
+   * disconnected and will be repaired.
+   *
+   * @param initializerGraph the source graph to copy from
+   * @param newOrdMap maps old ordinals to new ordinals; -1 indicates deleted documents
+   * @return map of level to list of disconnected node ordinals at that level
+   * @throws IOException if an I/O error occurs during graph traversal
+   */
+  private Map<Integer, List<Integer>> copyGraphStructure(
+      HnswGraph initializerGraph, int[] newOrdMap) throws IOException {
+    int numLevels = initializerGraph.numLevels();
+    levelToNodes = new IntArrayList[numLevels];
+    Map<Integer, List<Integer>> disconnectedNodesByLevel = new HashMap<>(numLevels);
+
+    for (int level = numLevels - 1; level >= 0; level--) {
+      levelToNodes[level] = new IntArrayList();
+      List<Integer> disconnectedNodes = new ArrayList<>();
       HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
+
       while (it.hasNext()) {
         int oldOrd = it.nextInt();
         int newOrd = newOrdMap[oldOrd];
+
+        // Skip deleted documents (mapped to -1)
+        if (newOrd == -1) {
+          hasDeletes = true;
+          continue;
+        }
+
         hnsw.addNode(level, newOrd);
+        levelToNodes[level].add(newOrd);
         hnsw.trySetNewEntryNode(newOrd, level);
+        scorer.setScoringOrdinal(newOrd);
+
+        // Copy neighbors
         NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
         initializerGraph.seek(level, oldOrd);
+        int oldNeighbourCount = 0;
         for (int oldNeighbor = initializerGraph.nextNeighbor();
             oldNeighbor != NO_MORE_DOCS;
             oldNeighbor = initializerGraph.nextNeighbor()) {
+          oldNeighbourCount++;
           int newNeighbor = newOrdMap[oldNeighbor];
-          // we will compute these scores later when we need to pop out the non-diverse nodes
-          newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
+
+          // Only add neighbors that weren't deleted
+          if (newNeighbor != -1) {
+            newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
+          }
+        }
+
+        // Mark as disconnected if the node lost more than the acceptable share of neighbors
+        if (newNeighbors.size() < oldNeighbourCount * DISCONNECTED_NODE_FACTOR) {
+          disconnectedNodes.add(newOrd);
         }
       }
+      disconnectedNodesByLevel.put(level, disconnectedNodes);
     }
-    return hnsw;
+    return disconnectedNodesByLevel;
   }
 
-  private final BitSet initializedNodes;
+  /**
+   * Repairs disconnected nodes at all levels by finding additional neighbors to restore
+   * connectivity.
+   *
+   * @param disconnectedNodesByLevel map of level to disconnected nodes at that level
+   * @param numLevels total number of levels in the graph hierarchy
+   * @throws IOException if an I/O error occurs during repair operations
+   */
+  private void repairDisconnectedNodes(
+      Map<Integer, List<Integer>> disconnectedNodesByLevel, int numLevels) throws IOException {
+    for (int level = numLevels - 1; level >= 0; level--) {
+      fixDisconnectedNodes(disconnectedNodesByLevel.get(level), level, scorer);
+    }
+  }
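
The disconnection test in copyGraphStructure is a single comparison. A standalone check using the 0.85 factor from this patch; the sample counts are arbitrary:

/** Standalone version of the disconnected-node check; counts below are arbitrary. */
public class DisconnectedCheckDemo {
  static final double DISCONNECTED_NODE_FACTOR = 0.85; // value used by this patch

  static boolean isDisconnected(int oldNeighborCount, int survivingNeighborCount) {
    return survivingNeighborCount < oldNeighborCount * DISCONNECTED_NODE_FACTOR;
  }

  public static void main(String[] args) {
    System.out.println(isDisconnected(20, 16)); // true: 16 < 20 * 0.85 = 17.0
    System.out.println(isDisconnected(20, 17)); // false: keeping 85% is still acceptable
  }
}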
 
-  public InitializedHnswGraphBuilder(
-      RandomVectorScorerSupplier scorerSupplier,
-      int beamWidth,
-      long seed,
-      OnHeapHnswGraph initializedGraph,
-      BitSet initializedNodes)
+  /**
+   * Fixes disconnected nodes at a specific level by performing graph searches from their existing
+   * neighbors to find additional connections.
+   *
+   * <p>For each disconnected node:
+   *
+   * <ol>
+   *   <li>Use existing neighbors as entry points for graph search
+   *   <li>Search the level to find candidate neighbors
+   *   <li>Add diverse neighbors using the HNSW heuristic selection algorithm
+   * </ol>
+   *
+   * <p>If a node has no neighbors at all, it cannot be repaired at this level and will rely on
+   * the rebalancing phase.
+   *
+   * @param disconnectedNodes list of node ordinals that need additional neighbors
+   * @param level the level at which to repair connections
+   * @param scorer vector similarity scorer for distance calculations
+   * @throws IOException if an I/O error occurs during search operations
+   */
+  private void fixDisconnectedNodes(
+      List<Integer> disconnectedNodes, int level, UpdateableRandomVectorScorer scorer)
       throws IOException {
-    super(scorerSupplier, beamWidth, seed, initializedGraph);
-    this.initializedNodes = initializedNodes;
+    if (disconnectedNodes.isEmpty()) return;
+
+    int beamWidth = beamCandidates.k();
+    GraphBuilderKnnCollector candidates = new GraphBuilderKnnCollector(beamWidth);
+    NeighborArray scratchArray = new NeighborArray(beamWidth, false);
+
+    for (int node : disconnectedNodes) {
+      scorer.setScoringOrdinal(node);
+      NeighborArray existingNeighbors = hnsw.getNeighbors(level, node);
+
+      // Only repair if node has at least one neighbor to use as entry point
+      if (existingNeighbors.size() > 0) {
+        // Use all existing neighbors as entry points for search
+        int[] entryPoints = new int[existingNeighbors.size()];
+        System.arraycopy(existingNeighbors.nodes(), 0, entryPoints, 0, existingNeighbors.size());
+
+        // Search from entry points to find candidate neighbors
+        graphSearcher.searchLevel(candidates, scorer, level, entryPoints, hnsw, null);
+        popToScratch(candidates, scratchArray);
+
+        // Add diverse neighbors using HNSW heuristic (prunes similar neighbors)
+        addDiverseNeighbors(level, node, scratchArray, scorer, true);
+      } else {
+        // Node has no neighbors, add connections from scratch
+        addConnections(node, level, scorer);
+      }
+
+      // Clear for next iteration
+      scratchArray.clear();
+      candidates.clear();
+    }
+  }
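
fixDisconnectedNodes leans on graphSearcher.searchLevel for candidate discovery. As a rough standalone analogue, the sketch below walks outward from a node's surviving neighbors and links whatever it reaches; a real repair scores candidates with the vector scorer and keeps only a diverse subset, so treat this as shape, not substance. The adjacency data is invented:

import java.util.*;

/** Toy repair pass: search outward from a node's surviving neighbors for new links. */
public class RepairDemo {
  public static void main(String[] args) {
    // Hypothetical level-0 adjacency lists after deletions removed some links
    Map<Integer, List<Integer>> adj = new HashMap<>(Map.of(
        0, new ArrayList<>(List.of(1)),
        1, new ArrayList<>(List.of(0, 2)),
        2, new ArrayList<>(List.of(1, 3)),
        3, new ArrayList<>(List.of(2))));

    int node = 0; // disconnected node; entry points = its surviving neighbors
    Set<Integer> candidates = new LinkedHashSet<>();
    Deque<Integer> frontier = new ArrayDeque<>(adj.get(node));
    Set<Integer> seen = new HashSet<>(List.of(node));
    while (!frontier.isEmpty()) {
      int n = frontier.poll();
      if (!seen.add(n)) continue;
      candidates.add(n);
      frontier.addAll(adj.get(n));
    }
    // A real repair would score and prune candidates; here we simply link them all
    for (int c : candidates) {
      if (!adj.get(node).contains(c)) {
        adj.get(node).add(c);
        adj.get(c).add(node); // bidirectional link, as in addDiverseNeighbors
      }
    }
    System.out.println(adj.get(node)); // [1, 2, 3]
  }
}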
 
+  /**
+   * Rebalances the graph hierarchy by promoting nodes from lower levels to higher levels to
+   * maintain the expected exponential decay in level sizes according to the HNSW probabilistic
+   * model.
+   *
+   * <p>The expected number of nodes at each level follows the formula: <br>
+   * {@code maxNodesAtLevel = totalNodes * (1/M)^level}
+   *
+   * <p>For each level that has fewer nodes than expected, this method randomly promotes nodes
+   * from the level below with probability 1/M until the target count is reached.
+   *
+   * <p>This rebalancing is necessary during graph merging, where deletions may have disrupted
+   * the proper hierarchical distribution, which could degrade the quality of semantic matches.
+   *
+   * @throws IOException if an I/O error occurs during node promotion
+   */
+  private void rebalanceGraph() throws IOException {
+    SplittableRandom random = new SplittableRandom();
+    int size = hnsw.size();
+    double invMaxConn = 1.0 / M;
+
+    // Process each level starting from level 1 (level 0 always contains all nodes)
+    for (int level = 1; ; level++) {
+
+      // Calculate expected number of nodes at this level
+      int maxNodesAtLevel = (int) (size * Math.pow(invMaxConn, level));
+      if (maxNodesAtLevel <= 0) break; // Stop when expected nodes drops to zero
+
+      int currentNodesAtLevel = 0;
+
+      // Expand levelToNodes array if we need to create new levels
+      if (level >= levelToNodes.length) {
+        levelToNodes = ArrayUtil.growExact(levelToNodes, level + 1);
+        levelToNodes[level] = new IntArrayList();
+      } else {
+        currentNodesAtLevel = levelToNodes[level].size();
+      }
+
+      // Skip if this level already has enough nodes
+      if (currentNodesAtLevel >= maxNodesAtLevel) continue;
+
+      // Randomly promote nodes from the level below
+      Iterator<IntCursor> it = levelToNodes[level - 1].iterator();
+
+      while (it.hasNext() && currentNodesAtLevel < maxNodesAtLevel) {
+        int node = it.next().value;
+
+        // Promote with probability 1/M, matching HNSW's level assignment distribution
+        if (random.nextDouble() < invMaxConn && !hnsw.nodeExistAtLevel(level, node)) {
+          scorer.setScoringOrdinal(node);
+          hnsw.addNode(level, node);
+
+          // If this is the first node at this level, try to make it the entry point
+          if (currentNodesAtLevel == 0) {
+            hnsw.tryPromoteNewEntryNode(node, level, hnsw.numLevels() - 1);
+          } else {
+            // Add connections for non-first nodes
+            addConnections(node, level, scorer);
+          }
+
+          levelToNodes[level].add(node);
+          currentNodesAtLevel++;
+        }
+      }
+    }
+  }
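
To make the rebalancing budget concrete, the formula maxNodesAtLevel = totalNodes * (1/M)^level can be tabulated directly; totalNodes and M below are arbitrary sample values:

/** Prints the expected node count per level used as the rebalancing target. */
public class LevelBudgetDemo {
  public static void main(String[] args) {
    int totalNodes = 10_000;
    int m = 16; // max connections on upper levels; also 1/m is the promotion probability
    double invM = 1.0 / m;

    // maxNodesAtLevel = totalNodes * (1/M)^level, until it rounds down to zero
    for (int level = 1; ; level++) {
      int maxNodesAtLevel = (int) (totalNodes * Math.pow(invM, level));
      if (maxNodesAtLevel <= 0) break;
      System.out.println("level " + level + ": " + maxNodesAtLevel);
    }
    // prints: level 1: 625, level 2: 39, level 3: 2; level 4 rounds to 0
  }
}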
 
+  /**
+   * Adds connections for an existing node at a specific level in the graph hierarchy.
+   *
+   * <p>The process involves:
+   *
+   * <ol>
+   *   <li>Navigate down from the top level to find the closest node at the target level
+   *   <li>Perform a full search at the target level to find neighbors
+   *   <li>Add diverse neighbors using the HNSW heuristic selection
+   * </ol>
+   *
+   * @param node the node ordinal to add connections for
+   * @param targetLevel the level to add connections at
+   * @param scorer vector similarity scorer for distance calculations
+   * @throws IOException if an I/O error occurs during search or neighbor addition
+   */
+  private void addConnections(int node, int targetLevel, UpdateableRandomVectorScorer scorer)
+      throws IOException {
+
+    int beamWidth = beamCandidates.k();
+    GraphBuilderKnnCollector candidates = new GraphBuilderKnnCollector(beamWidth);
+    int[] eps = {hnsw.entryNode()};
+
+    // Navigate down from top to target level, greedily moving toward the new node
+    for (int level = hnsw.numLevels() - 1; level > targetLevel; level--) {
+      graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
+      eps[0] = candidates.popNode();
+      candidates.clear();
+    }
+
+    // Perform full search at target level to find neighbors
+    graphSearcher.searchLevel(candidates, scorer, targetLevel, eps, hnsw, null);
+
+    NeighborArray scratchArray = new NeighborArray(beamWidth, false);
+    popToScratch(candidates, scratchArray);
+
+    // Add diverse neighbors and establish bidirectional connections
+    addDiverseNeighbors(targetLevel, node, scratchArray, scorer, true);
   }
 
   @Override
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/MergingHnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/MergingHnswGraphBuilder.java
index 08366927b247..fa9d0cbfe826 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/MergingHnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/MergingHnswGraphBuilder.java
@@ -99,7 +99,8 @@ public static MergingHnswGraphBuilder fromGraphs(
       BitSet initializedNodes)
       throws IOException {
     OnHeapHnswGraph graph =
-        InitializedHnswGraphBuilder.initGraph(graphs[0], ordMaps[0], totalNumberOfVectors);
+        InitializedHnswGraphBuilder.initGraph(
+            graphs[0], ordMaps[0], totalNumberOfVectors, beamWidth, scorerSupplier);
     return new MergingHnswGraphBuilder(
         scorerSupplier, beamWidth, seed, graph, graphs, ordMaps, initializedNodes);
   }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index adda151df93b..3165e90bc55d 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -151,13 +151,15 @@ public void addNode(int level, int node) {
       graph = ArrayUtil.grow(graph, node + 1);
     }
 
-    assert graph[node] == null || graph[node].length > level
-        : "node must be inserted from the top level";
+    assert graph[node] == null || graph[node].length >= level
+        : "node must be inserted at most one level above its current top level";
     if (graph[node] == null) {
-      graph[node] =
-          new NeighborArray[level + 1]; // assumption: we always call this function from top level
+      graph[node] = new NeighborArray[level + 1];
       size.incrementAndGet();
+    } else if (graph[node].length <= level) {
+      graph[node] = ArrayUtil.growExact(graph[node], level + 1);
     }
+
     if (level == 0) {
       graph[node][level] =
           new NeighborArray(
@@ -216,6 +218,10 @@ public int maxConn() {
     return nsize - 1;
   }
 
+  public boolean nodeExistAtLevel(int level, int node) {
+    return graph[node] != null && graph[node].length > level;
+  }
+
   /**
    * Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
    * level
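
The OnHeapHnswGraph.addNode change above swaps the "top level first" assertion for growth of the per-node level array, which is what lets rebalanceGraph promote an already-inserted node. A toy version of just that growth logic; String stands in for NeighborArray:

import java.util.Arrays;

/** Toy per-node level-array growth, mirroring the relaxed addNode behavior. */
public class LevelArrayGrowthDemo {
  // graph[node][level] would hold a neighbor list; String is a stand-in here
  static String[][] graph = new String[4][];

  static void addNode(int level, int node) {
    if (graph[node] == null) {
      graph[node] = new String[level + 1];
    } else if (graph[node].length <= level) {
      // previously this case was an assertion failure (nodes had to be inserted
      // top level first); growing instead allows later promotion to upper levels
      graph[node] = Arrays.copyOf(graph[node], level + 1);
    }
    graph[node][level] = "neighbors";
  }

  public static void main(String[] args) {
    addNode(0, 2); // node 2 enters at level 0
    addNode(1, 2); // later promoted to level 1: the array grows instead of asserting
    System.out.println(graph[2].length); // 2
  }
}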
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index 5efe13ca6714..f72d605aa1cc 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -206,6 +206,77 @@ public void testRandomReadWriteAndMerge() throws IOException {
     }
   }
 
+  @SuppressWarnings("unchecked")
+  public void testGraphMergeWithDeletes() throws IOException {
+
+    int M = 4;
+    int beamWidth = 20;
+    String vectorFieldName = "vec1";
+    int numVectors = random().nextInt(1000);
+    int deletionProbability = random().nextInt(100);
+    int dim = random().nextInt(64) + 1;
+    if (dim % 2 == 1) {
+      dim++;
+    }
+    KnnVectorValues vectors = vectorValues(numVectors, dim);
+    int deleteCount = 0;
+
+    try (Directory dir = newDirectory()) {
+      IndexWriterConfig iwc =
+          new IndexWriterConfig()
+              .setCodec(
+                  TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswVectorsFormat(M, beamWidth, 0)))
+              // set a random merge policy
+              .setMergePolicy(newMergePolicy(random()));
+      try (IndexWriter w = new IndexWriter(dir, iwc)) {
+
+        for (int i = 0; i < numVectors; i++) {
+          Document doc = new Document();
+          switch (vectors.getEncoding()) {
+            case BYTE -> {
+              doc.add(
+                  knnVectorField(
+                      vectorFieldName,
+                      (T) ((ByteVectorValues) vectors).vectorValue(i),
+                      similarityFunction));
+            }
+            case FLOAT32 -> {
+              doc.add(
+                  knnVectorField(
+                      vectorFieldName,
+                      (T) ((FloatVectorValues) vectors).vectorValue(i),
+                      similarityFunction));
+            }
+          }
+          doc.add(new StringField("id", Integer.toString(i), Field.Store.NO));
+          w.addDocument(doc);
+        }
+        w.commit();
+
+        for (int d = 0; d < numVectors; d++) {
+          if (random().nextInt(100) < deletionProbability) {
+            deleteCount++;
+            w.deleteDocuments(new Term("id", Integer.toString(d)));
+          }
+        }
+        w.commit();
+        w.forceMerge(1);
+      }
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        for (LeafReaderContext ctx : reader.leaves()) {
+          HnswGraph graphValues =
+              ((Lucene99HnswVectorsReader)
+                      ((CodecReader) ctx.reader())
+                          .getVectorReader()
+                          .unwrapReaderForField(vectorFieldName))
+                  .getGraph(vectorFieldName);
+          assertEquals(numVectors - deleteCount, graphValues.size());
+        }
+      }
+    }
+  }
+
   @SuppressWarnings("unchecked")
   private T vectorValue(KnnVectorValues vectors, int ord) throws IOException {
     switch (vectors.getEncoding()) {
@@ -548,12 +619,13 @@ public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws
     int initializerSize = random().nextInt(5, totalSize);
     int docIdOffset = 0;
     int dim = atLeast(10);
+    int beamWidth = 30;
     long seed = random().nextLong();
 
     KnnVectorValues initializerVectors = vectorValues(initializerSize, dim);
     RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
     HnswGraphBuilder initializerBuilder =
-        HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
+        HnswGraphBuilder.create(initialscorerSupplier, 10, beamWidth, seed);
 
     OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
     KnnVectorValues finalVectorValues =
@@ -568,7 +640,11 @@ public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws
     // another graph to do the assertion
     OnHeapHnswGraph graphAfterInit =
         InitializedHnswGraphBuilder.initGraph(
-            initializerGraph, initializerOrdMap, initializerGraph.size());
+            initializerGraph,
+            initializerOrdMap,
+            initializerGraph.size(),
+            beamWidth,
+            initialscorerSupplier);
 
     HnswGraphBuilder finalBuilder =
         InitializedHnswGraphBuilder.fromGraph(
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java
index 082079deb8f3..e11331b9db1d 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java
@@ -33,16 +33,6 @@ public void testNoGrowth() {
     expectThrows(IllegalStateException.class, () -> graph.addNode(1, 100));
   }
 
-  /* AssertionError will be thrown if we add a node not from top most level,
-  (likely NPE will be thrown in prod) */
-  public void testAddLevelOutOfOrder() {
-    OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
-    graph.addNode(0, 0);
-    if (TEST_ASSERTS_ENABLED) {
-      expectThrows(AssertionError.class, () -> graph.addNode(1, 0));
-    }
-  }
-
   /* assert exception will be thrown when we call getNodeOnLevel for an incomplete graph */
   public void testIncompleteGraphThrow() {
     OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);