Skip to content

Commit 7a68c6c

Browse files
Implement NodeQueue#pushAll and AbstractLongHeap#addAll (#415)
Adds bulk add operations to the AbstractLongHeap and to the NodeQueue to reduce comparisons required when adding many elements at once.
1 parent 41ce85d commit 7a68c6c

File tree

5 files changed

+200
-2
lines changed

5 files changed

+200
-2
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828
import io.github.jbellis.jvector.util.AbstractLongHeap;
2929
import io.github.jbellis.jvector.util.BoundedLongHeap;
3030
import io.github.jbellis.jvector.util.NumericUtils;
31+
import java.util.PrimitiveIterator;
3132
import org.agrona.collections.Int2ObjectHashMap;
3233

33-
import java.util.Arrays;
34-
3534
import static java.lang.Math.min;
3635

3736
/**
@@ -90,6 +89,16 @@ public boolean push(int newNode, float newScore) {
9089
return heap.push(encode(newNode, newScore));
9190
}
9291

92+
/**
93+
* Encodes then adds all elements from the given iterator to this heap, in bulk.
94+
*
95+
* @param nodeScoreIterator the node and score pairs to add
96+
* @param count the number of elements to add
97+
*/
98+
public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
99+
heap.pushAll(new NodeScoreIteratorConverter(nodeScoreIterator, this), count);
100+
}
101+
93102
/**
94103
* Encodes the node ID and its similarity score as long. If two scores are equals,
95104
* the smaller node ID wins.
@@ -260,6 +269,18 @@ public interface NodeConsumer {
260269
void accept(int node, float score);
261270
}
262271

272+
/** Iterator over node and score pairs. */
273+
public interface NodeScoreIterator {
274+
/** @return true if there are more elements */
275+
boolean hasNext();
276+
277+
/** @return the next node id */
278+
int nextNode();
279+
280+
/** @return the next node score and advance the iterator */
281+
float nextScore();
282+
}
283+
263284
/**
264285
* Copies the other NodeQueue to this one. If its order (MIN_HEAP or MAX_HEAP) is the same as this,
265286
* it is copied verbatim. If it differs, every lement is re-inserted into this.
@@ -274,4 +295,28 @@ public void copyFrom(NodeQueue other) {
274295
other.foreach(this::push);
275296
}
276297
}
298+
299+
/**
300+
* Converts a NodeScoreIterator to a PrimitiveIterator.OfLong by encoding the node and score as a long.
301+
*/
302+
private static class NodeScoreIteratorConverter implements PrimitiveIterator.OfLong {
303+
private final NodeScoreIterator it;
304+
private final NodeQueue queue;
305+
306+
public NodeScoreIteratorConverter(NodeScoreIterator it, NodeQueue queue) {
307+
this.it = it;
308+
this.queue = queue;
309+
}
310+
311+
@Override
312+
public boolean hasNext() {
313+
return it.hasNext();
314+
}
315+
316+
@Override
317+
public long nextLong() {
318+
// Call to nextScore() advances the iterator
319+
return queue.encode(it.nextNode(), it.nextScore());
320+
}
321+
}
277322
}

jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
package io.github.jbellis.jvector.util;
2626

2727
import io.github.jbellis.jvector.annotations.VisibleForTesting;
28+
import java.util.PrimitiveIterator;
2829

2930
/**
3031
* A min heap that stores longs; a primitive priority queue that like all priority queues maintains
@@ -64,6 +65,14 @@ public AbstractLongHeap(int initialSize) {
6465
*/
6566
public abstract boolean push(long element);
6667

68+
/**
69+
* Adds all elements from the given iterator to this heap, in bulk.
70+
*
71+
* @param elements the elements to add
72+
* @param elementsSize the number of elements to add
73+
*/
74+
public abstract void pushAll(PrimitiveIterator.OfLong elements, int elementsSize);
75+
6776
protected long add(long element) {
6877
size++;
6978
if (size == heap.length) {
@@ -74,6 +83,38 @@ protected long add(long element) {
7483
return heap[1];
7584
}
7685

86+
/**
87+
* Bulk-adds all elements from the given iterator to this heap, then re-heapifies
88+
* in O(n) time (Floyd's build-heap). For a proof explaining the linear time
89+
* complexity, see <a href="https://stackoverflow.com/a/18742428">this stackoverflow answer</a>.
90+
*
91+
* @param elements the elements to add
92+
* @param elementsSize the number of elements to add
93+
*/
94+
protected void addAll(PrimitiveIterator.OfLong elements, int elementsSize) {
95+
if (!elements.hasNext()) {
96+
return; // nothing to do
97+
}
98+
99+
// 1) Ensure we have enough capacity
100+
int newSize = size + elementsSize;
101+
if (newSize >= heap.length) {
102+
heap = ArrayUtil.grow(heap, newSize);
103+
}
104+
105+
// 2) Copy the new elements directly into the array
106+
while (elements.hasNext()) {
107+
heap[++size] = elements.nextLong();
108+
}
109+
110+
// 3) "Bottom-up" re-heapify:
111+
// Start from the last non-leaf node (size >>> 1) down to the root (1).
112+
// This is Floyd's build-heap algorithm.
113+
for (int i = size >>> 1; i >= 1; i--) {
114+
downHeap(i);
115+
}
116+
}
117+
77118
/**
78119
* Returns the least element of the LongHeap in constant time. It is up to the caller to verify
79120
* that the heap is not empty; no checking is done, and if no elements have been added, 0 is

jvector-base/src/main/java/io/github/jbellis/jvector/util/BoundedLongHeap.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
package io.github.jbellis.jvector.util;
2626

2727
import io.github.jbellis.jvector.annotations.VisibleForTesting;
28+
import java.util.PrimitiveIterator;
2829

2930
/**
3031
* An AbstractLongHeap with an adjustable maximum size.
@@ -67,6 +68,15 @@ public boolean push(long value) {
6768
return true;
6869
}
6970

71+
@Override
72+
public void pushAll(PrimitiveIterator.OfLong elements, int elementsSize)
73+
{
74+
if (elementsSize + size >= maxSize) {
75+
throw new IllegalArgumentException("Cannot add more elements than maxSize");
76+
}
77+
addAll(elements, elementsSize);
78+
}
79+
7080
/**
7181
* Replace the top of the heap with {@code newTop}, and enforce the heap invariant.
7282
* Should be called when the top value changes.

jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableLongHeap.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
package io.github.jbellis.jvector.util;
2626

27+
import java.util.PrimitiveIterator;
28+
2729
/**
2830
* An AbstractLongHeap that can grow in size (unbounded, except for memory and array size limits).
2931
*/
@@ -47,4 +49,10 @@ public boolean push(long element) {
4749
add(element);
4850
return true;
4951
}
52+
53+
@Override
54+
public void pushAll(PrimitiveIterator.OfLong elements, int elementsSize)
55+
{
56+
addAll(elements, elementsSize);
57+
}
5058
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeQueue.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,68 @@ public void testUnboundedQueue() {
118118
assertEquals(maxNode, nn.topNode());
119119
}
120120

121+
@Test
122+
public void testPushAllMinHeap() {
123+
// Build a NodeQueue with a GrowableLongHeap, using MIN_HEAP order
124+
NodeQueue queue = new NodeQueue(new GrowableLongHeap(2), NodeQueue.Order.MIN_HEAP);
125+
126+
// Let's prepare some node, score pairs
127+
int[] nodes = { 5, 1, 3, 2, 8 };
128+
float[] scores = { 2.2f, -1.0f, 0.5f, 2.1f, -0.9f };
129+
130+
// We'll create a TestNodeScoreIterator with these arrays
131+
TestNodeScoreIterator it = new TestNodeScoreIterator(nodes, scores);
132+
133+
// Bulk-add all pairs in one go
134+
queue.pushAll(it, nodes.length);
135+
136+
// The queue should now contain 5 elements
137+
assertEquals(5, queue.size());
138+
139+
// Because it's a MIN_HEAP, the top (root) should be the "smallest" score
140+
// We have scores: [2.2, -1.0, 0.5, 2.1, -0.9]
141+
// The minimum is -1.0. Let's see which node that corresponds to: node=1
142+
assertEquals(-1.0f, queue.topScore(), 0.000001);
143+
assertEquals(1, queue.topNode());
144+
}
145+
146+
@Test
147+
public void testPushAllMaxHeap() {
148+
// Build a NodeQueue with a GrowableLongHeap, using MAX_HEAP order
149+
NodeQueue queue = new NodeQueue(new GrowableLongHeap(2), NodeQueue.Order.MAX_HEAP);
150+
151+
// Let's prepare some node, score pairs
152+
int[] nodes = { 10, 20, 30, 40, 50 };
153+
float[] scores = { -2.5f, 1.0f, 0.0f, 1.5f, 3.0f };
154+
155+
// We'll create a TestNodeScoreIterator with these arrays
156+
TestNodeScoreIterator it = new TestNodeScoreIterator(nodes, scores);
157+
158+
// Bulk-add all pairs in one go
159+
queue.pushAll(it, nodes.length);
160+
161+
// The queue should now contain 5 elements
162+
assertEquals(5, queue.size());
163+
164+
// Because it's a MAX_HEAP, the top (root) should be the "largest" score
165+
// The largest among [-2.5, 1.0, 0.0, 1.5, 3.0] is 3.0 => node=50
166+
assertEquals(3.0f, queue.topScore(), 0.000001);
167+
assertEquals(50, queue.topNode());
168+
}
169+
170+
@Test
171+
public void testPushAllBoundedHeapExceedsCapacity() {
172+
assertThrows(IllegalArgumentException.class, () -> {
173+
NodeQueue queue = new NodeQueue(new BoundedLongHeap(2), NodeQueue.Order.MAX_HEAP);
174+
queue.pushAll(new TestNodeScoreIterator(new int[] { 1, 2, 3 }, new float[] { 1, 2, 3 }), 3);
175+
});
176+
NodeQueue queue = new NodeQueue(new BoundedLongHeap(2), NodeQueue.Order.MAX_HEAP);
177+
queue.push(1, 1);
178+
assertThrows(IllegalArgumentException.class, () -> {
179+
queue.pushAll(new TestNodeScoreIterator(new int[] { 1, 2 }, new float[] { 1, 2 }), 2);
180+
});
181+
}
182+
121183
@Test
122184
public void testInvalidArguments() {
123185
assertThrows(IllegalArgumentException.class, () -> new NodeQueue(new GrowableLongHeap(0), NodeQueue.Order.MIN_HEAP));
@@ -127,4 +189,36 @@ public void testInvalidArguments() {
127189
public void testToString() {
128190
assertEquals("Nodes[0]", new NodeQueue(new GrowableLongHeap(2), NodeQueue.Order.MIN_HEAP).toString());
129191
}
192+
193+
/**
194+
* Simple iterator that yields a fixed array of (node, score) pairs
195+
* for testing the pushAll method.
196+
*/
197+
static class TestNodeScoreIterator implements NodeQueue.NodeScoreIterator {
198+
private final int[] nodes;
199+
private final float[] scores;
200+
private int index = 0;
201+
202+
TestNodeScoreIterator(int[] nodes, float[] scores) {
203+
assert nodes.length == scores.length;
204+
this.nodes = nodes;
205+
this.scores = scores;
206+
}
207+
208+
@Override
209+
public boolean hasNext() {
210+
return index < nodes.length;
211+
}
212+
213+
@Override
214+
public int nextNode() {
215+
return nodes[index];
216+
}
217+
218+
@Override
219+
public float nextScore() {
220+
return scores[index++];
221+
}
222+
}
223+
130224
}

0 commit comments

Comments
 (0)