Skip to content

Commit ceec058

Browse files
committed
Refactor around NeighborArray (#12910)
1 parent 1a4c853 commit ceec058

File tree

12 files changed: 139 additions and 113 deletions

12 files changed: 139 additions and 113 deletions

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@ Improvements
3131

3232
* GITHUB#12812: Avoid overflows and false negatives in int slice buffer filled-with-zeros assertion. (Stefan Vodita)
3333

34+
* GITHUB#12910: Refactor around NeighborArray to make it more self-contained. (Patrick Zhai)
35+
3436
Optimizations
3537
---------------------
3638
(No changes)

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,7 @@ private OnHeapHnswGraph writeGraph(
297297
int size = neighbors.size();
298298
vectorIndex.writeInt(size);
299299
// Destructively modify; it's ok we are discarding it after this
300-
int[] nnodes = neighbors.node();
300+
int[] nnodes = neighbors.nodes();
301301
Arrays.sort(nnodes, 0, size);
302302
for (int i = 0; i < size; i++) {
303303
int nnode = nnodes[i];

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -369,7 +369,7 @@ private void reconstructAndWriteNeigbours(
369369
vectorIndex.writeInt(size);
370370

371371
// Destructively modify; it's ok we are discarding it after this
372-
int[] nnodes = neighbors.node();
372+
int[] nnodes = neighbors.nodes();
373373
for (int i = 0; i < size; i++) {
374374
nnodes[i] = oldToNewMap[nnodes[i]];
375375
}
@@ -506,7 +506,7 @@ private void writeGraph(OnHeapHnswGraph graph) throws IOException {
506506
int size = neighbors.size();
507507
vectorIndex.writeInt(size);
508508
// Destructively modify; it's ok we are discarding it after this
509-
int[] nnodes = neighbors.node();
509+
int[] nnodes = neighbors.nodes();
510510
Arrays.sort(nnodes, 0, size);
511511
for (int i = 0; i < size; i++) {
512512
int nnode = nnodes[i];

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -396,7 +396,7 @@ private void reconstructAndWriteNeigbours(
396396
vectorIndex.writeVInt(size);
397397

398398
// Destructively modify; it's ok we are discarding it after this
399-
int[] nnodes = neighbors.node();
399+
int[] nnodes = neighbors.nodes();
400400
for (int i = 0; i < size; i++) {
401401
nnodes[i] = oldToNewMap[nnodes[i]];
402402
}
@@ -556,7 +556,7 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
556556
long offsetStart = vectorIndex.getFilePointer();
557557
vectorIndex.writeVInt(size);
558558
// Destructively modify; it's ok we are discarding it after this
559-
int[] nnodes = neighbors.node();
559+
int[] nnodes = neighbors.nodes();
560560
Arrays.sort(nnodes, 0, size);
561561
// Now that we have sorted, do delta encoding to minimize the required bits to store the
562562
// information

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -319,7 +319,7 @@ private void reconstructAndWriteNeighbours(NeighborArray neighbors, int[] oldToN
319319
vectorIndex.writeVInt(size);
320320

321321
// Destructively modify; it's ok we are discarding it after this
322-
int[] nnodes = neighbors.node();
322+
int[] nnodes = neighbors.nodes();
323323
for (int i = 0; i < size; i++) {
324324
nnodes[i] = oldToNewMap[nnodes[i]];
325325
}
@@ -415,7 +415,7 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
415415
long offsetStart = vectorIndex.getFilePointer();
416416
vectorIndex.writeVInt(size);
417417
// Destructively modify; it's ok we are discarding it after this
418-
int[] nnodes = neighbors.node();
418+
int[] nnodes = neighbors.nodes();
419419
Arrays.sort(nnodes, 0, size);
420420
// Now that we have sorted, do delta encoding to minimize the required bits to store the
421421
// information

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,7 @@ void graphSeek(HnswGraph graph, int level, int targetNode) {
201201
nodeBuffer = new int[neighborArray.size()];
202202
}
203203
size = neighborArray.size();
204-
if (size >= 0) System.arraycopy(neighborArray.node, 0, nodeBuffer, 0, size);
204+
if (size >= 0) System.arraycopy(neighborArray.nodes(), 0, nodeBuffer, 0, size);
205205
} finally {
206206
neighborArray.rwlock.readLock().unlock();
207207
}

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java

Lines changed: 5 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -310,15 +310,11 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
310310
if (mask[i] == false) {
311311
continue;
312312
}
313-
int nbr = candidates.node[i];
313+
int nbr = candidates.nodes()[i];
314314
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
315315
nbrsOfNbr.rwlock.writeLock().lock();
316316
try {
317-
nbrsOfNbr.addOutOfOrder(node, candidates.score[i]);
318-
if (nbrsOfNbr.size() > maxConnOnLevel) {
319-
int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
320-
nbrsOfNbr.removeIndex(indexToRemove);
321-
}
317+
nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier);
322318
} finally {
323319
nbrsOfNbr.rwlock.writeLock().unlock();
324320
}
@@ -336,8 +332,8 @@ private boolean[] selectAndLinkDiverse(
336332
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
337333
// compare each neighbor (in distance order) against the closer neighbors selected so far,
338334
// only adding it if it is closer to the target than to any of the other selected neighbors
339-
int cNode = candidates.node[i];
340-
float cScore = candidates.score[i];
335+
int cNode = candidates.nodes()[i];
336+
float cScore = candidates.scores()[i];
341337
assert cNode <= hnsw.maxNodeId();
342338
if (diversityCheck(cNode, cScore, neighbors)) {
343339
mask[i] = true;
@@ -371,70 +367,14 @@ private boolean diversityCheck(int candidate, float score, NeighborArray neighbo
371367
throws IOException {
372368
RandomVectorScorer scorer = scorerSupplier.scorer(candidate);
373369
for (int i = 0; i < neighbors.size(); i++) {
374-
float neighborSimilarity = scorer.score(neighbors.node[i]);
370+
float neighborSimilarity = scorer.score(neighbors.nodes()[i]);
375371
if (neighborSimilarity >= score) {
376372
return false;
377373
}
378374
}
379375
return true;
380376
}
381377

382-
/**
383-
* Find first non-diverse neighbour among the list of neighbors starting from the most distant
384-
* neighbours
385-
*/
386-
private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException {
387-
RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd);
388-
int[] uncheckedIndexes = neighbors.sort(scorer);
389-
if (uncheckedIndexes == null) {
390-
// all nodes are checked, we will directly return the most distant one
391-
return neighbors.size() - 1;
392-
}
393-
int uncheckedCursor = uncheckedIndexes.length - 1;
394-
for (int i = neighbors.size() - 1; i > 0; i--) {
395-
if (uncheckedCursor < 0) {
396-
// no unchecked node left
397-
break;
398-
}
399-
if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
400-
return i;
401-
}
402-
if (i == uncheckedIndexes[uncheckedCursor]) {
403-
uncheckedCursor--;
404-
}
405-
}
406-
return neighbors.size() - 1;
407-
}
408-
409-
private boolean isWorstNonDiverse(
410-
int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
411-
throws IOException {
412-
float minAcceptedSimilarity = neighbors.score[candidateIndex];
413-
RandomVectorScorer scorer = scorerSupplier.scorer(neighbors.node[candidateIndex]);
414-
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
415-
// the candidate itself is unchecked
416-
for (int i = candidateIndex - 1; i >= 0; i--) {
417-
float neighborSimilarity = scorer.score(neighbors.node[i]);
418-
// candidate node is too similar to node i given its score relative to the base node
419-
if (neighborSimilarity >= minAcceptedSimilarity) {
420-
return true;
421-
}
422-
}
423-
} else {
424-
// else we just need to make sure candidate does not violate diversity with the (newly
425-
// inserted) unchecked nodes
426-
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
427-
for (int i = uncheckedCursor; i >= 0; i--) {
428-
float neighborSimilarity = scorer.score(neighbors.node[uncheckedIndexes[i]]);
429-
// candidate node is too similar to node i given its score relative to the base node
430-
if (neighborSimilarity >= minAcceptedSimilarity) {
431-
return true;
432-
}
433-
}
434-
}
435-
return false;
436-
}
437-
438378
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
439379
double randDouble;
440380
do {

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -308,7 +308,7 @@ void graphSeek(HnswGraph graph, int level, int targetNode) {
308308
@Override
309309
int graphNextNeighbor(HnswGraph graph) {
310310
if (++upto < cur.size()) {
311-
return cur.node[upto];
311+
return cur.nodes()[upto];
312312
}
313313
return NO_MORE_DOCS;
314314
}

0 commit comments

Comments
 (0)