22
22
import java .util .ArrayList ;
23
23
import java .util .Collections ;
24
24
import java .util .List ;
25
- import java .util .function .Consumer ;
26
25
27
26
import com .google .common .annotations .VisibleForTesting ;
28
27
import com .google .common .base .MoreObjects ;
29
- import com .google .common .util .concurrent .Runnables ;
30
28
import org .slf4j .Logger ;
31
29
import org .slf4j .LoggerFactory ;
32
30
31
+ import io .github .jbellis .jvector .graph .NodeQueue ;
33
32
import io .github .jbellis .jvector .quantization .CompressedVectors ;
34
33
import io .github .jbellis .jvector .quantization .ProductQuantization ;
35
34
import io .github .jbellis .jvector .util .BitSet ;
36
35
import io .github .jbellis .jvector .util .Bits ;
36
+ import io .github .jbellis .jvector .util .BoundedLongHeap ;
37
37
import io .github .jbellis .jvector .util .SparseBits ;
38
38
import io .github .jbellis .jvector .vector .VectorizationProvider ;
39
39
import io .github .jbellis .jvector .vector .types .VectorFloat ;
54
54
import org .apache .cassandra .index .sai .disk .vector .BruteForceRowIdIterator ;
55
55
import org .apache .cassandra .index .sai .disk .vector .CassandraDiskAnn ;
56
56
import org .apache .cassandra .index .sai .disk .vector .CloseableReranker ;
57
+ import org .apache .cassandra .index .sai .disk .vector .NodeQueueRowIdIterator ;
57
58
import org .apache .cassandra .index .sai .disk .vector .VectorCompression ;
58
59
import org .apache .cassandra .index .sai .disk .vector .VectorMemtableIndex ;
59
60
import org .apache .cassandra .index .sai .iterators .KeyRangeIterator ;
60
61
import org .apache .cassandra .index .sai .plan .Expression ;
61
62
import org .apache .cassandra .index .sai .plan .Orderer ;
62
63
import org .apache .cassandra .index .sai .plan .Plan .CostCoefficients ;
63
- import org .apache .cassandra .index .sai .utils .IntIntPairArray ;
64
+ import org .apache .cassandra .index .sai .utils .SegmentRowIdOrdinalPairs ;
64
65
import org .apache .cassandra .index .sai .utils .PrimaryKey ;
65
66
import org .apache .cassandra .index .sai .utils .PrimaryKeyWithSortKey ;
66
67
import org .apache .cassandra .index .sai .utils .RangeUtil ;
72
73
import org .apache .cassandra .metrics .QuickSlidingWindowReservoir ;
73
74
import org .apache .cassandra .tracing .Tracing ;
74
75
import org .apache .cassandra .utils .CloseableIterator ;
75
- import org .apache .cassandra .utils .SortingIterator ;
76
76
77
77
import static java .lang .Math .ceil ;
78
78
import static java .lang .Math .min ;
@@ -230,7 +230,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
230
230
if (initialCostEstimate .shouldUseBruteForce ())
231
231
{
232
232
var maxSize = endSegmentRowId - startSegmentRowId + 1 ;
233
- var segmentOrdinalPairs = new IntIntPairArray (maxSize );
233
+ var segmentOrdinalPairs = new SegmentRowIdOrdinalPairs (maxSize );
234
234
try (var ordinalsView = graph .getOrdinalsView ())
235
235
{
236
236
ordinalsView .forEachOrdinalInRange (startSegmentRowId , endSegmentRowId , segmentOrdinalPairs ::add );
@@ -270,7 +270,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
270
270
}
271
271
}
272
272
273
- private CloseableIterator <RowIdWithScore > orderByBruteForce (VectorFloat <?> queryVector , IntIntPairArray segmentOrdinalPairs , int limit , int rerankK ) throws IOException
273
+ private CloseableIterator <RowIdWithScore > orderByBruteForce (VectorFloat <?> queryVector , SegmentRowIdOrdinalPairs segmentOrdinalPairs , int limit , int rerankK ) throws IOException
274
274
{
275
275
// If we use compressed vectors, we still have to order rerankK results using full resolution similarity
276
276
// scores, so only use the compressed vectors when there are enough vectors to make it worthwhile.
@@ -289,33 +289,44 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
289
289
*/
290
290
private CloseableIterator <RowIdWithScore > orderByBruteForce (CompressedVectors cv ,
291
291
VectorFloat <?> queryVector ,
292
- IntIntPairArray segmentOrdinalPairs ,
292
+ SegmentRowIdOrdinalPairs segmentOrdinalPairs ,
293
293
int limit ,
294
294
int rerankK ) throws IOException
295
295
{
296
- var approximateScores = new SortingIterator .Builder <BruteForceRowIdIterator .RowWithApproximateScore >(segmentOrdinalPairs .size ());
296
+ // Use the jvector NodeQueue to avoid unnecessary object allocations since this part of the code operates on
297
+ // many rows.
298
// The backing heap is sized to the number of (rowId, ordinal) pairs, one slot per pair pushed below.
+ var approximateScores = new NodeQueue (new BoundedLongHeap (segmentOrdinalPairs .size ()), NodeQueue .Order .MAX_HEAP );
297
299
var similarityFunction = indexContext .getIndexWriterConfig ().getSimilarityFunction ();
298
300
// Scorer built from the compressed vectors for this query; it is invoked once per ordinal in the loop below.
var scoreFunction = cv .precomputedScoreFunctionFor (queryVector , similarityFunction );
299
301
300
- segmentOrdinalPairs .forEachIntPair ((segmentRowId , ordinal ) -> {
301
- var score = scoreFunction .similarityTo (ordinal );
302
- approximateScores .add (new BruteForceRowIdIterator .RowWithApproximateScore (segmentRowId , ordinal , score ));
302
+ // Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can
303
+ // retrieve both values with O(1) lookup when we need to resolve the full resolution score in the
304
+ // BruteForceRowIdIterator.
305
+ segmentOrdinalPairs .forEachIndexOrdinalPair ((i , ordinal ) -> {
306
+ approximateScores .push (i , scoreFunction .similarityTo (ordinal ));
303
307
});
304
- var approximateScoresQueue = approximateScores .build (BruteForceRowIdIterator .RowWithApproximateScore ::compare );
305
308
// The reranker wraps a freshly opened graph view and is handed to the returned iterator together with the
// queue and the pairs. NOTE(review): presumably BruteForceRowIdIterator closes the reranker - confirm.
var reranker = new CloseableReranker (similarityFunction , queryVector , graph .getView ());
306
- return new BruteForceRowIdIterator (approximateScoresQueue , reranker , limit , rerankK );
309
+ return new BruteForceRowIdIterator (approximateScores , segmentOrdinalPairs , reranker , limit , rerankK );
307
310
}
308
311
309
312
/**
310
313
* Produces a correct ranking of the rows in the given segment. Because this graph does not have compressed
311
314
* vectors, read all vectors and put them into a priority queue to rank them lazily. It is assumed that the whole
312
315
* PQ will often not be needed.
313
316
*/
314
- private CloseableIterator <RowIdWithScore > orderByBruteForce (VectorFloat <?> queryVector , IntIntPairArray segmentOrdinalPairs ) throws IOException
317
+ private CloseableIterator <RowIdWithScore > orderByBruteForce (VectorFloat <?> queryVector , SegmentRowIdOrdinalPairs segmentOrdinalPairs ) throws IOException
315
318
{
316
- var scoredRowIds = new SortingIterator .Builder <RowIdWithScore >(segmentOrdinalPairs .size ());
317
- addScoredRowIdsToCollector (queryVector , segmentOrdinalPairs , 0 , scoredRowIds ::add );
318
- return scoredRowIds .closeable (RowIdWithScore ::compare , Runnables .doNothing ());
319
// The backing heap is sized to the number of (segmentRowId, ordinal) pairs that are scored below.
+ var scoredRowIds = new NodeQueue (new BoundedLongHeap (segmentOrdinalPairs .size ()), NodeQueue .Order .MAX_HEAP );
320
+ try (var vectorsView = graph .getView ())
321
+ {
322
+ var similarityFunction = indexContext .getIndexWriterConfig ().getSimilarityFunction ();
323
+ var esf = vectorsView .rerankerFor (queryVector , similarityFunction );
324
+ // Because the scores are exact, we only store the rowid, score pair.
325
+ segmentOrdinalPairs .forEachSegmentRowIdOrdinalPair ((segmentRowId , ordinal ) -> {
326
+ scoredRowIds .push (segmentRowId , esf .similarityTo (ordinal ));
327
+ });
328
// Scoring happens inside this try block, so the returned iterator is built only from the already-filled
// queue and vectorsView is closed on exit (assumes the forEach callback above runs eagerly - TODO confirm).
+ return new NodeQueueRowIdIterator (scoredRowIds );
329
+ }
319
330
}
320
331
321
332
/**
@@ -324,29 +335,21 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
324
335
* NOTE: because the threshold is not used for ordering, the result is returned in PK order, not score order.
325
336
*/
326
337
private CloseableIterator <RowIdWithScore > filterByBruteForce (VectorFloat <?> queryVector ,
327
- IntIntPairArray segmentOrdinalPairs ,
338
+ SegmentRowIdOrdinalPairs segmentOrdinalPairs ,
328
339
float threshold ) throws IOException
329
340
{
330
341
// Presized for the case in which every pair meets the threshold.
var results = new ArrayList <RowIdWithScore >(segmentOrdinalPairs .size ());
331
- addScoredRowIdsToCollector (queryVector , segmentOrdinalPairs , threshold , results ::add );
332
- return CloseableIterator .wrap (results .iterator ());
333
- }
334
-
335
- private void addScoredRowIdsToCollector (VectorFloat <?> queryVector ,
336
- IntIntPairArray segmentOrdinalPairs ,
337
- float threshold ,
338
- Consumer <RowIdWithScore > collector ) throws IOException
339
- {
340
- var similarityFunction = indexContext .getIndexWriterConfig ().getSimilarityFunction ();
341
342
try (var vectorsView = graph .getView ())
342
343
{
344
+ var similarityFunction = indexContext .getIndexWriterConfig ().getSimilarityFunction ();
343
345
var esf = vectorsView .rerankerFor (queryVector , similarityFunction );
344
- segmentOrdinalPairs .forEachIntPair ((segmentRowId , ordinal ) -> {
346
+ segmentOrdinalPairs .forEachSegmentRowIdOrdinalPair ((segmentRowId , ordinal ) -> {
345
347
var score = esf .similarityTo (ordinal );
346
348
// Keep only rows whose score reaches the threshold; the comparison is inclusive.
if (score >= threshold )
347
- collector . accept (new RowIdWithScore (segmentRowId , score ));
349
+ results . add (new RowIdWithScore (segmentRowId , score ));
348
350
});
349
351
}
352
// The list is iterated in insertion order, i.e. the order in which the pairs were visited; no sort is applied.
+ return CloseableIterator .wrap (results .iterator ());
350
353
}
351
354
352
355
private long getMaxSSTableRowId (PrimaryKeyMap primaryKeyMap , PartitionPosition right )
@@ -489,7 +492,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
489
492
}
490
493
// Create bits from the mapping
491
494
var bits = bitSetForSearch ();
492
- segmentOrdinalPairs .forEachRightInt (bits ::set );
495
+ segmentOrdinalPairs .forEachOrdinal (bits ::set );
493
496
// else ask the index to perform a search limited to the bits we created
494
497
var queryVector = vts .createFloatVector (orderer .getVectorTerm ());
495
498
var results = graph .search (queryVector , limit , rerankK , 0 , bits , context , cost ::updateStatistics );
@@ -504,9 +507,9 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
504
507
* @return a mapping of segment row id to ordinal
505
508
* @throws IOException
506
509
*/
507
- private IntIntPairArray flatmapPrimaryKeysToBitsAndRows (List <PrimaryKey > keysInRange ) throws IOException
510
+ private SegmentRowIdOrdinalPairs flatmapPrimaryKeysToBitsAndRows (List <PrimaryKey > keysInRange ) throws IOException
508
511
{
509
- var segmentOrdinalPairs = new IntIntPairArray (keysInRange .size ());
512
+ var segmentOrdinalPairs = new SegmentRowIdOrdinalPairs (keysInRange .size ());
510
513
int lastSegmentRowId = -1 ;
511
514
try (var primaryKeyMap = primaryKeyMapFactory .newPerSSTablePrimaryKeyMap ();
512
515
var ordinalsView = graph .getOrdinalsView ())
0 commit comments