Skip to content

Commit 9eb18a6

Browse files
CNDB-13417: Reduce object alloc. on brute force hybrid ann path (#1643)
### What is the issue Relates to riptano/cndb#13417 (the issue has multiple PRs) ### What does this PR fix and why was it fixed In hybrid ANN search resulting in brute force row sorting, we see many object allocations per row materialized (so O(n) space complexity), and this creates memory pressure leading to reduced performance. This PR reduces object allocation by using jvector's `NodeQueue` data structure that encodes an int and a float into a long and then sorts based on the float. I also renamed the IntIntPairArray because I needed more specialized methods. - **Rename IntIntPairArray to SegmentRowIdOrdinalPairs** - **CNDB-13417: Reduce object alloc. on brute force hybrid ann path**
1 parent cb3de82 commit 9eb18a6

File tree

7 files changed

+335
-152
lines changed

7 files changed

+335
-152
lines changed

src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@
2222
import java.util.ArrayList;
2323
import java.util.Collections;
2424
import java.util.List;
25-
import java.util.function.Consumer;
2625

2726
import com.google.common.annotations.VisibleForTesting;
2827
import com.google.common.base.MoreObjects;
29-
import com.google.common.util.concurrent.Runnables;
3028
import org.slf4j.Logger;
3129
import org.slf4j.LoggerFactory;
3230

31+
import io.github.jbellis.jvector.graph.NodeQueue;
3332
import io.github.jbellis.jvector.quantization.CompressedVectors;
3433
import io.github.jbellis.jvector.quantization.ProductQuantization;
3534
import io.github.jbellis.jvector.util.BitSet;
3635
import io.github.jbellis.jvector.util.Bits;
36+
import io.github.jbellis.jvector.util.BoundedLongHeap;
3737
import io.github.jbellis.jvector.util.SparseBits;
3838
import io.github.jbellis.jvector.vector.VectorizationProvider;
3939
import io.github.jbellis.jvector.vector.types.VectorFloat;
@@ -54,13 +54,14 @@
5454
import org.apache.cassandra.index.sai.disk.vector.BruteForceRowIdIterator;
5555
import org.apache.cassandra.index.sai.disk.vector.CassandraDiskAnn;
5656
import org.apache.cassandra.index.sai.disk.vector.CloseableReranker;
57+
import org.apache.cassandra.index.sai.disk.vector.NodeQueueRowIdIterator;
5758
import org.apache.cassandra.index.sai.disk.vector.VectorCompression;
5859
import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex;
5960
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
6061
import org.apache.cassandra.index.sai.plan.Expression;
6162
import org.apache.cassandra.index.sai.plan.Orderer;
6263
import org.apache.cassandra.index.sai.plan.Plan.CostCoefficients;
63-
import org.apache.cassandra.index.sai.utils.IntIntPairArray;
64+
import org.apache.cassandra.index.sai.utils.SegmentRowIdOrdinalPairs;
6465
import org.apache.cassandra.index.sai.utils.PrimaryKey;
6566
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
6667
import org.apache.cassandra.index.sai.utils.RangeUtil;
@@ -72,7 +73,6 @@
7273
import org.apache.cassandra.metrics.QuickSlidingWindowReservoir;
7374
import org.apache.cassandra.tracing.Tracing;
7475
import org.apache.cassandra.utils.CloseableIterator;
75-
import org.apache.cassandra.utils.SortingIterator;
7676

7777
import static java.lang.Math.ceil;
7878
import static java.lang.Math.min;
@@ -230,7 +230,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
230230
if (initialCostEstimate.shouldUseBruteForce())
231231
{
232232
var maxSize = endSegmentRowId - startSegmentRowId + 1;
233-
var segmentOrdinalPairs = new IntIntPairArray(maxSize);
233+
var segmentOrdinalPairs = new SegmentRowIdOrdinalPairs(maxSize);
234234
try (var ordinalsView = graph.getOrdinalsView())
235235
{
236236
ordinalsView.forEachOrdinalInRange(startSegmentRowId, endSegmentRowId, segmentOrdinalPairs::add);
@@ -270,7 +270,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
270270
}
271271
}
272272

273-
private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> queryVector, IntIntPairArray segmentOrdinalPairs, int limit, int rerankK) throws IOException
273+
private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> queryVector, SegmentRowIdOrdinalPairs segmentOrdinalPairs, int limit, int rerankK) throws IOException
274274
{
275275
// If we use compressed vectors, we still have to order rerankK results using full resolution similarity
276276
// scores, so only use the compressed vectors when there are enough vectors to make it worthwhile.
@@ -289,33 +289,44 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
289289
*/
290290
private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv,
291291
VectorFloat<?> queryVector,
292-
IntIntPairArray segmentOrdinalPairs,
292+
SegmentRowIdOrdinalPairs segmentOrdinalPairs,
293293
int limit,
294294
int rerankK) throws IOException
295295
{
296-
var approximateScores = new SortingIterator.Builder<BruteForceRowIdIterator.RowWithApproximateScore>(segmentOrdinalPairs.size());
296+
// Use the jvector NodeQueue to avoid unnecessary object allocations since this part of the code operates on
297+
// many rows.
298+
var approximateScores = new NodeQueue(new BoundedLongHeap(segmentOrdinalPairs.size()), NodeQueue.Order.MAX_HEAP);
297299
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
298300
var scoreFunction = cv.precomputedScoreFunctionFor(queryVector, similarityFunction);
299301

300-
segmentOrdinalPairs.forEachIntPair((segmentRowId, ordinal) -> {
301-
var score = scoreFunction.similarityTo(ordinal);
302-
approximateScores.add(new BruteForceRowIdIterator.RowWithApproximateScore(segmentRowId, ordinal, score));
302+
// Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can
303+
// retrieve both values with O(1) lookup when we need to resolve the full resolution score in the
304+
// BruteForceRowIdIterator.
305+
segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> {
306+
approximateScores.push(i, scoreFunction.similarityTo(ordinal));
303307
});
304-
var approximateScoresQueue = approximateScores.build(BruteForceRowIdIterator.RowWithApproximateScore::compare);
305308
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
306-
return new BruteForceRowIdIterator(approximateScoresQueue, reranker, limit, rerankK);
309+
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK);
307310
}
308311

309312
/**
310313
* Produces a correct ranking of the rows in the given segment. Because this graph does not have compressed
311314
* vectors, read all vectors and put them into a priority queue to rank them lazily. It is assumed that the whole
312315
* PQ will often not be needed.
313316
*/
314-
private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> queryVector, IntIntPairArray segmentOrdinalPairs) throws IOException
317+
private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> queryVector, SegmentRowIdOrdinalPairs segmentOrdinalPairs) throws IOException
315318
{
316-
var scoredRowIds = new SortingIterator.Builder<RowIdWithScore>(segmentOrdinalPairs.size());
317-
addScoredRowIdsToCollector(queryVector, segmentOrdinalPairs, 0, scoredRowIds::add);
318-
return scoredRowIds.closeable(RowIdWithScore::compare, Runnables.doNothing());
319+
var scoredRowIds = new NodeQueue(new BoundedLongHeap(segmentOrdinalPairs.size()), NodeQueue.Order.MAX_HEAP);
320+
try (var vectorsView = graph.getView())
321+
{
322+
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
323+
var esf = vectorsView.rerankerFor(queryVector, similarityFunction);
324+
// Because the scores are exact, we only store the rowid, score pair.
325+
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
326+
scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal));
327+
});
328+
return new NodeQueueRowIdIterator(scoredRowIds);
329+
}
319330
}
320331

321332
/**
@@ -324,29 +335,21 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
324335
* NOTE: because the threshold is not used for ordering, the result is returned in PK order, not score order.
325336
*/
326337
private CloseableIterator<RowIdWithScore> filterByBruteForce(VectorFloat<?> queryVector,
327-
IntIntPairArray segmentOrdinalPairs,
338+
SegmentRowIdOrdinalPairs segmentOrdinalPairs,
328339
float threshold) throws IOException
329340
{
330341
var results = new ArrayList<RowIdWithScore>(segmentOrdinalPairs.size());
331-
addScoredRowIdsToCollector(queryVector, segmentOrdinalPairs, threshold, results::add);
332-
return CloseableIterator.wrap(results.iterator());
333-
}
334-
335-
private void addScoredRowIdsToCollector(VectorFloat<?> queryVector,
336-
IntIntPairArray segmentOrdinalPairs,
337-
float threshold,
338-
Consumer<RowIdWithScore> collector) throws IOException
339-
{
340-
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
341342
try (var vectorsView = graph.getView())
342343
{
344+
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
343345
var esf = vectorsView.rerankerFor(queryVector, similarityFunction);
344-
segmentOrdinalPairs.forEachIntPair((segmentRowId, ordinal) -> {
346+
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
345347
var score = esf.similarityTo(ordinal);
346348
if (score >= threshold)
347-
collector.accept(new RowIdWithScore(segmentRowId, score));
349+
results.add(new RowIdWithScore(segmentRowId, score));
348350
});
349351
}
352+
return CloseableIterator.wrap(results.iterator());
350353
}
351354

352355
private long getMaxSSTableRowId(PrimaryKeyMap primaryKeyMap, PartitionPosition right)
@@ -489,7 +492,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
489492
}
490493
// Create bits from the mapping
491494
var bits = bitSetForSearch();
492-
segmentOrdinalPairs.forEachRightInt(bits::set);
495+
segmentOrdinalPairs.forEachOrdinal(bits::set);
493496
// else ask the index to perform a search limited to the bits we created
494497
var queryVector = vts.createFloatVector(orderer.getVectorTerm());
495498
var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics);
@@ -504,9 +507,9 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
504507
* @return a mapping of segment row id to ordinal
505508
* @throws IOException
506509
*/
507-
private IntIntPairArray flatmapPrimaryKeysToBitsAndRows(List<PrimaryKey> keysInRange) throws IOException
510+
private SegmentRowIdOrdinalPairs flatmapPrimaryKeysToBitsAndRows(List<PrimaryKey> keysInRange) throws IOException
508511
{
509-
var segmentOrdinalPairs = new IntIntPairArray(keysInRange.size());
512+
var segmentOrdinalPairs = new SegmentRowIdOrdinalPairs(keysInRange.size());
510513
int lastSegmentRowId = -1;
511514
try (var primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
512515
var ordinalsView = graph.getOrdinalsView())

src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818

1919
package org.apache.cassandra.index.sai.disk.vector;
2020

21+
import io.github.jbellis.jvector.graph.NodeQueue;
22+
import io.github.jbellis.jvector.util.BoundedLongHeap;
23+
import org.apache.cassandra.index.sai.utils.SegmentRowIdOrdinalPairs;
2124
import org.apache.cassandra.index.sai.utils.RowIdWithMeta;
2225
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
2326
import org.apache.cassandra.io.util.FileUtils;
2427
import org.apache.cassandra.utils.AbstractIterator;
25-
import org.apache.cassandra.utils.LucenePriorityQueue;
2628
import org.apache.cassandra.utils.SortingIterator;
2729

2830

@@ -45,55 +47,39 @@
4547
* is consumed. We do this because we expect that most often the first limit-many will pass the final verification
4648
* and only query more if some didn't (e.g. because the vector was deleted in a newer sstable).
4749
* <p>
48-
* As an implementation detail, we use a PriorityQueue to maintain state rather than a List and sorting.
50+
* As an implementation detail, we use a heap to maintain state rather than a List and sorting.
4951
*/
5052
public class BruteForceRowIdIterator extends AbstractIterator<RowIdWithScore>
5153
{
52-
public static class RowWithApproximateScore
53-
{
54-
private final int rowId;
55-
private final int ordinal;
56-
private final float appoximateScore;
57-
58-
public RowWithApproximateScore(int rowId, int ordinal, float appoximateScore)
59-
{
60-
this.rowId = rowId;
61-
this.ordinal = ordinal;
62-
this.appoximateScore = appoximateScore;
63-
}
64-
65-
public static int compare(RowWithApproximateScore l, RowWithApproximateScore r)
66-
{
67-
// Inverted comparison to sort in descending order
68-
return Float.compare(r.appoximateScore, l.appoximateScore);
69-
}
70-
}
71-
72-
// We use two binary heaps (a SortingIterator and LucenePriorityQueue) because we do not need an eager ordering of
54+
// We use two binary heaps (NodeQueue) because we do not need an eager ordering of
7355
// these results. Depending on how many sstables the query hits and the relative scores of vectors from those
7456
// sstables, we may not need to return more than the first handful of scores.
75-
// Priority queue with compressed vector scores
76-
private final SortingIterator<RowWithApproximateScore> approximateScoreQueue;
77-
// Priority queue with full resolution scores
78-
private final LucenePriorityQueue<RowIdWithScore> exactScoreQueue;
57+
// Heap with compressed vector scores
58+
private final NodeQueue approximateScoreQueue;
59+
private final SegmentRowIdOrdinalPairs segmentOrdinalPairs;
60+
// Use the jvector NodeQueue to avoid unnecessary object allocations
61+
private final NodeQueue exactScoreQueue;
7962
private final CloseableReranker reranker;
8063
private final int topK;
8164
private final int limit;
8265
private int rerankedCount;
8366

8467
/**
85-
* @param approximateScoreQueue A priority queue of rows and their ordinal ordered by their approximate similarity scores
68+
* @param approximateScoreQueue A heap of indexes ordered by their approximate similarity scores
69+
* @param segmentOrdinalPairs A mapping from the index in the approximateScoreQueue to the node's rowId and ordinal
8670
* @param reranker A function that takes a graph ordinal and returns the exact similarity score
8771
* @param limit The query limit
8872
* @param topK The number of vectors to resolve and score before returning results
8973
*/
90-
public BruteForceRowIdIterator(SortingIterator<RowWithApproximateScore> approximateScoreQueue,
74+
public BruteForceRowIdIterator(NodeQueue approximateScoreQueue,
75+
SegmentRowIdOrdinalPairs segmentOrdinalPairs,
9176
CloseableReranker reranker,
9277
int limit,
9378
int topK)
9479
{
9580
this.approximateScoreQueue = approximateScoreQueue;
96-
this.exactScoreQueue = new LucenePriorityQueue<>(topK, RowIdWithScore::compare);
81+
this.segmentOrdinalPairs = segmentOrdinalPairs;
82+
this.exactScoreQueue = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MAX_HEAP);
9783
this.reranker = reranker;
9884
assert topK >= limit : "topK must be greater than or equal to limit. Found: " + topK + " < " + limit;
9985
this.limit = limit;
@@ -106,15 +92,21 @@ protected RowIdWithScore computeNext() {
10692
int consumed = rerankedCount - exactScoreQueue.size();
10793
if (consumed >= limit) {
10894
// Refill the exactScoreQueue until it reaches topK exact scores, or the approximate score queue is empty
109-
while (approximateScoreQueue.hasNext() && exactScoreQueue.size() < topK) {
110-
RowWithApproximateScore rowOrdinalScore = approximateScoreQueue.next();
111-
float score = reranker.similarityTo(rowOrdinalScore.ordinal);
112-
exactScoreQueue.add(new RowIdWithScore(rowOrdinalScore.rowId, score));
95+
while (approximateScoreQueue.size() > 0 && exactScoreQueue.size() < topK) {
96+
int segmentOrdinalIndex = approximateScoreQueue.pop();
97+
int rowId = segmentOrdinalPairs.getSegmentRowId(segmentOrdinalIndex);
98+
int ordinal = segmentOrdinalPairs.getOrdinal(segmentOrdinalIndex);
99+
float score = reranker.similarityTo(ordinal);
100+
exactScoreQueue.push(rowId, score);
113101
}
114102
rerankedCount = exactScoreQueue.size();
115103
}
116-
RowIdWithScore top = exactScoreQueue.pop();
117-
return top == null ? endOfData() : top;
104+
if (exactScoreQueue.size() == 0)
105+
return endOfData();
106+
107+
float score = exactScoreQueue.topScore();
108+
int rowId = exactScoreQueue.pop();
109+
return new RowIdWithScore(rowId, score);
118110
}
119111

120112
@Override
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.apache.cassandra.index.sai.disk.vector;
18+
19+
import io.github.jbellis.jvector.graph.NodeQueue;
20+
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
21+
import org.apache.cassandra.utils.AbstractIterator;
22+
23+
/**
24+
* An iterator over {@link RowIdWithScore} that lazily consumes a {@link NodeQueue}.
25+
*/
26+
public class NodeQueueRowIdIterator extends AbstractIterator<RowIdWithScore>
27+
{
28+
private final NodeQueue scoreQueue;
29+
30+
public NodeQueueRowIdIterator(NodeQueue scoreQueue)
31+
{
32+
this.scoreQueue = scoreQueue;
33+
}
34+
35+
@Override
36+
protected RowIdWithScore computeNext()
37+
{
38+
if (scoreQueue.size() == 0)
39+
return endOfData();
40+
float score = scoreQueue.topScore();
41+
int rowId = scoreQueue.pop();
42+
return new RowIdWithScore(rowId, score);
43+
}
44+
}

0 commit comments

Comments
 (0)