Skip to content

Commit c243698

Browse files
authored
Refactoring HNSWGraphBuilder's API and adding more comments about concurrency (#15184)
1 parent 13a7e1e commit c243698

File tree

7 files changed

+65
-66
lines changed

7 files changed

+65
-66
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ New Features
134134

135135
Improvements
136136
---------------------
137-
# GITHUB#15148: Add support uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)
137+
* GITHUB#15148: Add support for uint8 distance and allow 8-bit scalar quantization (Trevor McCulloch)
138+
139+
* GITHUB#15184: Refactor internal HnswGraphBuilder APIs and avoid creating a new scorer for each call (Patrick Zhai)
138140

139141
Optimizations
140142
---------------------

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
import org.apache.lucene.util.hnsw.NeighborArray;
6060
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
6161
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
62-
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
6362
import org.apache.lucene.util.packed.DirectMonotonicWriter;
6463

6564
/**
@@ -586,7 +585,6 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
586585
private int lastDocID = -1;
587586
private int node = 0;
588587
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
589-
private UpdateableRandomVectorScorer scorer;
590588

591589
@SuppressWarnings("unchecked")
592590
static FieldWriter<?> create(
@@ -642,7 +640,6 @@ static FieldWriter<?> create(
642640
(List<float[]>) flatFieldVectorsWriter.getVectors(),
643641
fieldInfo.getVectorDimension()));
644642
};
645-
this.scorer = scorerSupplier.scorer();
646643
hnswGraphBuilder =
647644
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
648645
hnswGraphBuilder.setInfoStream(infoStream);
@@ -658,8 +655,7 @@ public void addValue(int docID, T vectorValue) throws IOException {
658655
+ "\" appears more than once in this document (only one value is allowed per field)");
659656
}
660657
flatFieldVectorsWriter.addValue(docID, vectorValue);
661-
scorer.setScoringOrdinal(node);
662-
hnswGraphBuilder.addGraphNode(node, scorer);
658+
hnswGraphBuilder.addGraphNode(node);
663659
node++;
664660
lastDocID = docID;
665661
}

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.lucene.util.hnsw;
1919

2020
import java.io.IOException;
21+
import org.apache.lucene.internal.hppc.IntHashSet;
2122
import org.apache.lucene.util.InfoStream;
2223

2324
/**
@@ -34,16 +35,21 @@ public interface HnswBuilder {
3435
*/
3536
OnHeapHnswGraph build(int maxOrd) throws IOException;
3637

37-
/** Inserts a doc with vector value to the graph */
38+
/** Inserts a doc with a vector value to the graph */
3839
void addGraphNode(int node) throws IOException;
3940

41+
/**
42+
* Inserts a doc with a vector value to the graph, searching on level 0 with provided entry points
43+
*/
44+
void addGraphNode(int node, IntHashSet eps) throws IOException;
45+
4046
/** Set info-stream to output debugging information */
4147
void setInfoStream(InfoStream infoStream);
4248

4349
OnHeapHnswGraph getGraph();
4450

4551
/**
46-
* Once this method is called no further updates to the graph are accepted (addGraphNode will
52+
* Once this method is called, no further updates to the graph are accepted (addGraphNode will
4753
* throw IllegalStateException). Final modifications to the graph (eg patching up disconnected
4854
* components, re-ordering node ids for better delta compression) may be triggered, so callers
4955
* should expect this call to take some time.

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.concurrent.Callable;
2626
import java.util.concurrent.atomic.AtomicInteger;
2727
import java.util.concurrent.locks.Lock;
28+
import org.apache.lucene.internal.hppc.IntHashSet;
2829
import org.apache.lucene.search.TaskExecutor;
2930
import org.apache.lucene.util.BitSet;
3031
import org.apache.lucene.util.FixedBitSet;
@@ -98,6 +99,11 @@ public void addGraphNode(int node) throws IOException {
9899
throw new UnsupportedOperationException("This builder is for merge only");
99100
}
100101

102+
@Override
103+
public void addGraphNode(int node, IntHashSet eps) throws IOException {
104+
throw new UnsupportedOperationException("This builder is for merge only");
105+
}
106+
101107
@Override
102108
public void setInfoStream(InfoStream infoStream) {
103109
this.infoStream = infoStream;
@@ -142,7 +148,6 @@ private static final class ConcurrentMergeWorker extends HnswGraphBuilder {
142148

143149
private final BitSet initializedNodes;
144150
private int batchSize = DEFAULT_BATCH_SIZE;
145-
private final UpdateableRandomVectorScorer scorer;
146151

147152
private ConcurrentMergeWorker(
148153
RandomVectorScorerSupplier scorerSupplier,
@@ -163,7 +168,6 @@ private ConcurrentMergeWorker(
163168
new NeighborQueue(beamWidth, true), hnswLock, new FixedBitSet(hnsw.maxNodeId() + 1)));
164169
this.workProgress = workProgress;
165170
this.initializedNodes = initializedNodes;
166-
this.scorer = scorerSupplier.scorer();
167171
}
168172

169173
/**
@@ -192,21 +196,12 @@ private int getStartPos(int maxOrd) {
192196
}
193197
}
194198

195-
@Override
196-
public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
197-
if (initializedNodes != null && initializedNodes.get(node)) {
198-
return;
199-
}
200-
super.addGraphNode(node, scorer);
201-
}
202-
203199
@Override
204200
public void addGraphNode(int node) throws IOException {
205201
if (initializedNodes != null && initializedNodes.get(node)) {
206202
return;
207203
}
208-
scorer.setScoringOrdinal(node);
209-
addGraphNode(node, scorer);
204+
super.addGraphNode(node);
210205
}
211206
}
212207

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
/**
4040
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
4141
* hyper-parameters.
42+
*
43+
* <p>Thread-safety: This class is NOT thread safe and cannot be shared across threads. However,
44+
* it IS safe for multiple HnswGraphBuilder instances to build the same graph if the graph's size
45+
* is known at the beginning (e.g. when doing a merge).
4246
*/
4347
public class HnswGraphBuilder implements HnswBuilder {
4448

@@ -64,7 +68,7 @@ public class HnswGraphBuilder implements HnswBuilder {
6468
private final double ml;
6569

6670
private final SplittableRandom random;
67-
protected final RandomVectorScorerSupplier scorerSupplier;
71+
private final UpdateableRandomVectorScorer scorer;
6872
private final HnswGraphSearcher graphSearcher;
6973
private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search
7074
private final GraphBuilderKnnCollector
@@ -144,8 +148,8 @@ protected HnswGraphBuilder(
144148
throw new IllegalArgumentException("beamWidth must be positive");
145149
}
146150
this.M = hnsw.maxConn();
147-
this.scorerSupplier =
148-
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
151+
this.scorer =
152+
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null").scorer();
149153
// normalization factor for level generation; currently not configurable
150154
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
151155
this.random = new SplittableRandom(seed);
@@ -196,10 +200,8 @@ protected void addVectors(int minOrd, int maxOrd) throws IOException {
196200
if (infoStream.isEnabled(HNSW_COMPONENT)) {
197201
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
198202
}
199-
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
200203
for (int node = minOrd; node < maxOrd; node++) {
201-
scorer.setScoringOrdinal(node);
202-
addGraphNode(node, scorer);
204+
addGraphNode(node);
203205
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
204206
t = printGraphBuildStatus(node, start, t);
205207
}
@@ -210,10 +212,24 @@ private void addVectors(int maxOrd) throws IOException {
210212
addVectors(0, maxOrd);
211213
}
212214

213-
public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
214-
addGraphNodeInternal(node, scorer, null);
215-
}
216-
215+
/**
216+
* Note: this implementation is thread safe when the graph size is fixed (e.g. when merging). The
217+
* process of adding a node is roughly: 1. Add the node to all levels from top to the bottom, but
218+
* do not connect it to any other node, nor try to promote itself to an entry node before the
219+
* connection is done. (Unless the graph is empty and this is the first node, in that case we set
220+
* the entry node and return) 2. Do the search from top to bottom, remember all the possible
221+
* neighbours on each level the node is on. 3. Add the neighbor to the node from bottom to top
222+
* level. When adding the neighbour, we always add all the outgoing links first before adding an
223+
* incoming link such that when a search visits this node, it can always find a way out. 4. If the
224+
* node has a level that is less or equal to the graph's max level, then we're done here. If the
225+
* node has a level larger than the graph's max level, then we need to promote the node as the
226+
* entry node. If, while we add the node to the graph, the entry node has changed (which means the
227+
* graph level has changed as well), we need to reinsert the node to the newly introduced levels
228+
* (repeating step 2,3 for new levels) and again try to promote the node to entry node.
229+
*
230+
* @param eps0 If specified, used as the entry points of the search on level 0; this is useful
231+
* when you have some prior knowledge, e.g. in {@link MergingHnswGraphBuilder}
232+
*/
217233
private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer, IntHashSet eps0)
218234
throws IOException {
219235
if (frozen) {
@@ -224,7 +240,8 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
224240
for (int level = nodeLevel; level >= 0; level--) {
225241
hnsw.addNode(level, node);
226242
}
227-
// then promote itself as entry node if entry node is not set
243+
// then promote itself as entry node if entry node is not set (this is the first ever node of
244+
// the graph)
228245
if (hnsw.trySetNewEntryNode(node, nodeLevel)) {
229246
return;
230247
}
@@ -235,8 +252,12 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
235252
int curMaxLevel;
236253
do {
237254
curMaxLevel = hnsw.numLevels() - 1;
238-
// NOTE: the entry node and max level may not be paired, but because we get the level first
255+
// NOTE: the entry node and max level are not retrieved synchronously, which could lead to a
256+
// situation where
257+
// the entry node's level is different from the graph's max level, but because we get the
258+
// level first,
239259
// we ensure that the entry node we get later will always exist on the curMaxLevel
260+
// i.e., curMaxLevel <= entryNode.level
240261
int[] eps = new int[] {hnsw.entryNode()};
241262

242263
// we first do the search from top to bottom
@@ -271,15 +292,21 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
271292
}
272293
lowestUnsetLevel += scratchPerLevel.length;
273294
assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
274-
if (lowestUnsetLevel > nodeLevel) {
295+
if (lowestUnsetLevel == nodeLevel + 1) {
296+
// we have already set all the levels we need for this node
275297
return;
276298
}
277299
assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel;
300+
// The node's level is higher than the graph's max level, so we need to
301+
// try to promote this node as the graph's entry node
278302
if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
279303
return;
280304
}
305+
// If we're not able to promote, it means the graph must have already changed
306+
// and has a new max level and some other entry node
281307
if (hnsw.numLevels() == curMaxLevel + 1) {
282-
// This should never happen if all the calculations are correct
308+
// This should be impossible; if it happens, then one of the assumptions
309+
// above does not hold
283310
throw new IllegalStateException(
284311
"We're not able to promote node "
285312
+ node
@@ -294,31 +321,12 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
294321

295322
@Override
296323
public void addGraphNode(int node) throws IOException {
297-
/*
298-
* Note: this implementation is thread safe when graph size is fixed (e.g. when merging)
299-
* The process of adding a node is roughly:
300-
* 1. Add the node to all level from top to the bottom, but do not connect it to any other node,
301-
* nor try to promote itself to an entry node before the connection is done. (Unless the graph is empty
302-
* and this is the first node, in that case we set the entry node and return)
303-
* 2. Do the search from top to bottom, remember all the possible neighbours on each level the node
304-
* is on.
305-
* 3. Add the neighbor to the node from bottom to top level, when adding the neighbour,
306-
* we always add all the outgoing links first before adding incoming link such that
307-
* when a search visits this node, it can always find a way out
308-
* 4. If the node has level that is less or equal to graph level, then we're done here.
309-
* If the node has level larger than graph level, then we need to promote the node
310-
* as the entry node. If, while we add the node to the graph, the entry node has changed
311-
* (which means the graph level has changed as well), we need to reinsert the node
312-
* to the newly introduced levels (repeating step 2,3 for new levels) and again try to
313-
* promote the node to entry node.
314-
*/
315-
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
316324
scorer.setScoringOrdinal(node);
317325
addGraphNodeInternal(node, scorer, null);
318326
}
319327

320-
public void addGraphNodeWithEps(int node, IntHashSet eps0) throws IOException {
321-
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
328+
@Override
329+
public void addGraphNode(int node, IntHashSet eps0) throws IOException {
322330
scorer.setScoringOrdinal(node);
323331
addGraphNodeInternal(node, scorer, eps0);
324332
}
@@ -486,7 +494,6 @@ private boolean connectComponents(int level) throws IOException {
486494
// while linking
487495
GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(2);
488496
int[] eps = new int[1];
489-
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
490497
for (Component c : components) {
491498
if (c != c0) {
492499
if (c.start() == NO_MORE_DOCS) {

lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,6 @@ public InitializedHnswGraphBuilder(
9898
this.initializedNodes = initializedNodes;
9999
}
100100

101-
@Override
102-
public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
103-
if (initializedNodes.get(node)) {
104-
return;
105-
}
106-
super.addGraphNode(node, scorer);
107-
}
108-
109101
@Override
110102
public void addGraphNode(int node) throws IOException {
111103
if (initializedNodes.get(node)) {

lucene/core/src/java/org/apache/lucene/util/hnsw/MergingHnswGraphBuilder.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ private MergingHnswGraphBuilder(
8484
* @param ordMaps the ordinal maps for the graphs
8585
* @param totalNumberOfVectors the total number of vectors in the new graph, this should include
8686
* all vectors expected to be added to the graph in the future
87-
* @param initializedNodes the nodes will be initialized through the merging
87+
* @param initializedNodes the nodes that will be initialized through the merging; if null, all
88+
* nodes should already be initialized after {@link #updateGraph(HnswGraph, int[])} is called
8889
* @return a new HnswGraphBuilder that is initialized with the provided HnswGraph
8990
* @throws IOException when reading the graph fails
9091
*/
@@ -172,7 +173,7 @@ private void updateGraph(HnswGraph gS, int[] ordMapS) throws IOException {
172173
}
173174
}
174175
}
175-
addGraphNodeWithEps(ordMapS[u], eps);
176+
addGraphNode(ordMapS[u], eps);
176177
}
177178
}
178179
}

0 commit comments

Comments
 (0)