1717package io .github .jbellis .jvector .example ;
1818
1919import io .github .jbellis .jvector .disk .ReaderSupplierFactory ;
20+ import io .github .jbellis .jvector .example .util .AccuracyMetrics ;
2021import io .github .jbellis .jvector .example .util .CompressorParameters ;
2122import io .github .jbellis .jvector .example .util .DataSet ;
2223import io .github .jbellis .jvector .graph .GraphIndex ;
5556import java .nio .file .Files ;
5657import java .nio .file .Path ;
5758import java .nio .file .Paths ;
58- import java .util .Arrays ;
5959import java .util .EnumMap ;
6060import java .util .HashMap ;
6161import java .util .IdentityHashMap ;
@@ -86,6 +86,7 @@ static void runAll(DataSet ds,
8686 List <? extends Set <FeatureId >> featureSets ,
8787 List <Function <DataSet , CompressorParameters >> buildCompressors ,
8888 List <Function <DataSet , CompressorParameters >> compressionGrid ,
89+ List <Integer > topKGrid ,
8990 List <Double > efSearchFactor ,
9091 List <Boolean > usePruningGrid ) throws IOException
9192 {
@@ -97,7 +98,7 @@ static void runAll(DataSet ds,
9798 for (int efC : efConstructionGrid ) {
9899 for (var bc : buildCompressors ) {
99100 var compressor = getCompressor (bc , ds );
100- runOneGraph (featureSets , M , efC , neighborOverflow , addHierarchy , compressor , compressionGrid , efSearchFactor , usePruningGrid , ds , testDirectory );
101+ runOneGraph (featureSets , M , efC , neighborOverflow , addHierarchy , compressor , compressionGrid , topKGrid , efSearchFactor , usePruningGrid , ds , testDirectory );
101102 }
102103 }
103104 }
@@ -123,6 +124,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
123124 boolean addHierarchy ,
124125 VectorCompressor <?> buildCompressor ,
125126 List <Function <DataSet , CompressorParameters >> compressionGrid ,
127+ List <Integer > topKGrid ,
126128 List <Double > efSearchOptions ,
127129 List <Boolean > usePruningGrid ,
128130 DataSet ds ,
@@ -151,7 +153,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
151153 indexes .forEach ((features , index ) -> {
152154 try (var cs = new ConfiguredSystem (ds , index , cv ,
153155 index instanceof OnDiskGraphIndex ? ((OnDiskGraphIndex ) index ).getFeatureSet () : Set .of ())) {
154- testConfiguration (cs , efSearchOptions , usePruningGrid );
156+ testConfiguration (cs , topKGrid , efSearchOptions , usePruningGrid );
155157 } catch (Exception e ) {
156158 throw new RuntimeException (e );
157159 }
@@ -361,22 +363,20 @@ private static Map<Set<FeatureId>, GraphIndex> buildInMemory(List<? extends Set<
361363 // avoid recomputing the compressor repeatedly (this is a relatively small memory footprint)
362364 static final Map <String , VectorCompressor <?>> cachedCompressors = new IdentityHashMap <>();
363365
364- private static void testConfiguration (ConfiguredSystem cs , List <Double > efSearchOptions , List <Boolean > usePruningGrid ) {
365- var topK = cs .ds .groundTruth .get (0 ).size ();
366+ private static void testConfiguration (ConfiguredSystem cs , List <Integer > topKGrid , List <Double > efSearchOptions , List <Boolean > usePruningGrid ) {
366367 int queryRuns = 2 ;
367368 System .out .format ("Using %s:%n" , cs .index );
368- for (var overquery : efSearchOptions ) {
369- int rerankK = (int ) (topK * overquery );
370- for (var usePruning : usePruningGrid ) {
371- var startTime = System .nanoTime ();
372- var pqr = performQueries (cs , topK , rerankK , usePruning , queryRuns );
373- var stopTime = System .nanoTime ();
374- var recall = ((double ) pqr .topKFound ) / (queryRuns * cs .ds .queryVectors .size () * topK );
375- System .out .format (" Query top %d/%d recall %.4f in %.2fms after %.2f nodes visited (AVG) and %.2f nodes expanded with pruning=%b%n" ,
376- topK , rerankK , recall , (stopTime - startTime ) / (queryRuns * 1_000_000.0 ),
377- (double ) pqr .nodesVisited / (queryRuns * cs .ds .queryVectors .size ()),
378- (double ) pqr .nodesExpanded / (queryRuns * cs .ds .queryVectors .size ()),
379- usePruning );
369+ for (var topK : topKGrid ) {
370+ for (var overquery : efSearchOptions ) {
371+ int rerankK = (int ) (topK * overquery );
372+ for (var usePruning : usePruningGrid ) {
373+ var pqr = performQueries (cs , topK , rerankK , usePruning , queryRuns );
374+ System .out .format (" Query top %d/%d recall %.4f in %.2fms after %.2f nodes visited (AVG) and %.2f nodes expanded with pruning=%b%n" ,
375+ topK , rerankK , pqr .recall , pqr .runtime / (1_000_000.0 ),
376+ (double ) pqr .nodesVisited / cs .ds .queryVectors .size (),
377+ (double ) pqr .nodesExpanded / cs .ds .queryVectors .size (),
378+ usePruning );
379+ }
380380 }
381381 }
382382 }
@@ -421,44 +421,36 @@ private static VectorCompressor<?> getCompressor(Function<DataSet, CompressorPar
421421 });
422422 }
423423
424- private static long topKCorrect (int topK , int [] resultNodes , Set <Integer > gt ) {
425- int count = Math .min (resultNodes .length , topK );
426- var resultSet = Arrays .stream (resultNodes , 0 , count )
427- .boxed ()
428- .collect (Collectors .toSet ());
429- assert resultSet .size () == count : String .format ("%s duplicate results out of %s" , count - resultSet .size (), count );
430- return resultSet .stream ().filter (gt ::contains ).count ();
431- }
432-
433- private static long topKCorrect (int topK , SearchResult .NodeScore [] nn , Set <Integer > gt ) {
434- var a = Arrays .stream (nn ).mapToInt (nodeScore -> nodeScore .node ).toArray ();
435- return topKCorrect (topK , a , gt );
436- }
437-
438424 private static ResultSummary performQueries (ConfiguredSystem cs , int topK , int rerankK , boolean usePruning , int queryRuns ) {
439- LongAdder topKfound = new LongAdder ();
440425 LongAdder nodesVisited = new LongAdder ();
441426 LongAdder nodesExpanded = new LongAdder ();
442427 LongAdder nodesExpandedBaseLayer = new LongAdder ();
428+ LongAdder runtime = new LongAdder ();
429+ double recall = 0 ;
430+
443431 for (int k = 0 ; k < queryRuns ; k ++) {
444- IntStream .range (0 , cs .ds .queryVectors .size ()).parallel ().forEach (i -> {
432+ var startTime = System .nanoTime ();
433+ List <SearchResult > listSR = IntStream .range (0 , cs .ds .queryVectors .size ()).parallel ().mapToObj (i -> {
445434 var queryVector = cs .ds .queryVectors .get (i );
446435 SearchResult sr ;
447436 var searcher = cs .getSearcher ();
448437 searcher .usePruning (usePruning );
449438 var sf = cs .scoreProviderFor (queryVector , searcher .getView ());
450439 sr = searcher .search (sf , topK , rerankK , 0.0f , 0.0f , Bits .ALL );
440+ return sr ;
441+ }).collect (Collectors .toList ());
442+ var stopTime = System .nanoTime ();
451443
452- // process search result
453- var gt = cs .ds .groundTruth .get (i );
454- var n = topKCorrect (topK , sr .getNodes (), gt );
455- topKfound .add (n );
444+ runtime .add (stopTime - startTime );
445+ // process search result
446+ listSR .stream ().parallel ().forEach (sr -> {
456447 nodesVisited .add (sr .getVisitedCount ());
457448 nodesExpanded .add (sr .getExpandedCount ());
458449 nodesExpandedBaseLayer .add (sr .getExpandedCountBaseLayer ());
459450 });
451+ recall += AccuracyMetrics .recallFromSearchResults (cs .ds .groundTruth , listSR , topK , topK );
460452 }
461- return new ResultSummary (( int ) topKfound .sum (), nodesVisited .sum (), nodesExpanded .sum (), nodesExpandedBaseLayer .sum ());
453+ return new ResultSummary (recall / queryRuns , nodesVisited .sum () / queryRuns , nodesExpanded .sum () / queryRuns , nodesExpandedBaseLayer .sum () / queryRuns , runtime .sum () / queryRuns );
462454 }
463455
464456 static class ConfiguredSystem implements AutoCloseable {
@@ -506,16 +498,18 @@ public void close() throws Exception {
506498 }
507499
508500 static class ResultSummary {
509- final int topKFound ;
501+ final double recall ;
510502 final long nodesVisited ;
511503 final long nodesExpanded ;
512504 final long nodesExpandedBaseLayer ;
505+ final long runtime ;
513506
514- ResultSummary (int topKFound , long nodesVisited , long nodesExpanded , long nodesExpandedBaseLayer ) {
515- this .topKFound = topKFound ;
507+ ResultSummary (double recall , long nodesVisited , long nodesExpanded , long nodesExpandedBaseLayer , long runtime ) {
508+ this .recall = recall ;
516509 this .nodesVisited = nodesVisited ;
517510 this .nodesExpanded = nodesExpanded ;
518511 this .nodesExpandedBaseLayer = nodesExpandedBaseLayer ;
512+ this .runtime = runtime ;
519513 }
520514 }
521515}
0 commit comments