1717package io .github .jbellis .jvector .example ;
1818
1919import io .github .jbellis .jvector .disk .ReaderSupplierFactory ;
20- import io .github .jbellis .jvector .example .benchmarks .*;
21- import io .github .jbellis .jvector .example .util .AccuracyMetrics ;
20+ import io .github .jbellis .jvector .example .benchmarks .AccuracyBenchmark ;
21+ import io .github .jbellis .jvector .example .benchmarks .BenchmarkTablePrinter ;
22+ import io .github .jbellis .jvector .example .benchmarks .CountBenchmark ;
23+ import io .github .jbellis .jvector .example .benchmarks .LatencyBenchmark ;
24+ import io .github .jbellis .jvector .example .benchmarks .QueryBenchmark ;
25+ import io .github .jbellis .jvector .example .benchmarks .QueryTester ;
26+ import io .github .jbellis .jvector .example .benchmarks .ThroughputBenchmark ;
2227import io .github .jbellis .jvector .example .util .CompressorParameters ;
2328import io .github .jbellis .jvector .example .util .DataSet ;
2429import io .github .jbellis .jvector .example .util .FilteredForkJoinPool ;
2732import io .github .jbellis .jvector .graph .GraphSearcher ;
2833import io .github .jbellis .jvector .graph .OnHeapGraphIndex ;
2934import io .github .jbellis .jvector .graph .RandomAccessVectorValues ;
30- import io .github .jbellis .jvector .graph .SearchResult ;
3135import io .github .jbellis .jvector .graph .disk .feature .Feature ;
3236import io .github .jbellis .jvector .graph .disk .feature .FeatureId ;
3337import io .github .jbellis .jvector .graph .disk .feature .FusedADC ;
4549import io .github .jbellis .jvector .quantization .PQVectors ;
4650import io .github .jbellis .jvector .quantization .ProductQuantization ;
4751import io .github .jbellis .jvector .quantization .VectorCompressor ;
48- import io .github .jbellis .jvector .util .Bits ;
4952import io .github .jbellis .jvector .util .ExplicitThreadLocal ;
5053import io .github .jbellis .jvector .util .PhysicalCoreExecutor ;
5154import io .github .jbellis .jvector .vector .types .VectorFloat ;
5962import java .nio .file .Files ;
6063import java .nio .file .Path ;
6164import java .nio .file .Paths ;
65+ import java .util .ArrayList ;
6266import java .util .EnumMap ;
6367import java .util .HashMap ;
6468import java .util .IdentityHashMap ;
6569import java .util .List ;
6670import java .util .Map ;
6771import java .util .Set ;
68- import java .util .concurrent .atomic .LongAdder ;
69- import java .util .concurrent .ForkJoinPool ;
7072import java .util .function .Function ;
7173import java .util .function .IntFunction ;
72- import java .util .stream .Collectors ;
7374import java .util .stream .IntStream ;
7475
7576/**
@@ -91,7 +92,8 @@ static void runAll(DataSet ds,
9192 List <Function <DataSet , CompressorParameters >> buildCompressors ,
9293 List <Function <DataSet , CompressorParameters >> compressionGrid ,
9394 Map <Integer , List <Double >> topKGrid ,
94- List <Boolean > usePruningGrid ) throws IOException
95+ List <Boolean > usePruningGrid ,
96+ Map <String , List <String >> benchmarks ) throws IOException
9597 {
9698 var testDirectory = Files .createTempDirectory (dirPrefix );
9799 try {
@@ -102,7 +104,7 @@ static void runAll(DataSet ds,
102104 for (int efC : efConstructionGrid ) {
103105 for (var bc : buildCompressors ) {
104106 var compressor = getCompressor (bc , ds );
105- runOneGraph (featureSets , M , efC , neighborOverflow , addHierarchy , refineFinalGraph , compressor , compressionGrid , topKGrid , usePruningGrid , ds , testDirectory );
107+ runOneGraph (featureSets , M , efC , neighborOverflow , addHierarchy , refineFinalGraph , compressor , compressionGrid , topKGrid , usePruningGrid , benchmarks , ds , testDirectory );
106108 }
107109 }
108110 }
@@ -122,6 +124,21 @@ static void runAll(DataSet ds,
122124 }
123125 }
124126
127+ static void runAll (DataSet ds ,
128+ List <Integer > mGrid ,
129+ List <Integer > efConstructionGrid ,
130+ List <Float > neighborOverflowGrid ,
131+ List <Boolean > addHierarchyGrid ,
132+ List <Boolean > refineFinalGraphGrid ,
133+ List <? extends Set <FeatureId >> featureSets ,
134+ List <Function <DataSet , CompressorParameters >> buildCompressors ,
135+ List <Function <DataSet , CompressorParameters >> compressionGrid ,
136+ Map <Integer , List <Double >> topKGrid ,
137+ List <Boolean > usePruningGrid ) throws IOException
138+ {
139+ runAll (ds , mGrid , efConstructionGrid , neighborOverflowGrid , addHierarchyGrid , refineFinalGraphGrid , featureSets , buildCompressors , compressionGrid , topKGrid , usePruningGrid , null );
140+ }
141+
125142 static void runOneGraph (List <? extends Set <FeatureId >> featureSets ,
126143 int M ,
127144 int efConstruction ,
@@ -132,6 +149,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
132149 List <Function <DataSet , CompressorParameters >> compressionGrid ,
133150 Map <Integer , List <Double >> topKGrid ,
134151 List <Boolean > usePruningGrid ,
152+ Map <String , List <String >> benchmarks ,
135153 DataSet ds ,
136154 Path testDirectory ) throws IOException
137155 {
@@ -158,7 +176,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
158176 indexes .forEach ((features , index ) -> {
159177 try (var cs = new ConfiguredSystem (ds , index , cv ,
160178 index instanceof OnDiskGraphIndex ? ((OnDiskGraphIndex ) index ).getFeatureSet () : Set .of ())) {
161- testConfiguration (cs , topKGrid , usePruningGrid , M , efConstruction , neighborOverflow , addHierarchy );
179+ testConfiguration (cs , topKGrid , usePruningGrid , M , efConstruction , neighborOverflow , addHierarchy , benchmarks );
162180 } catch (Exception e ) {
163181 throw new RuntimeException (e );
164182 }
@@ -379,17 +397,13 @@ private static void testConfiguration(ConfiguredSystem cs,
379397 int M ,
380398 int efConstruction ,
381399 float neighborOverflow ,
382- boolean addHierarchy ) {
400+ boolean addHierarchy ,
401+ Map <String , List <String >> benchmarkSpec ) {
383402 int queryRuns = 2 ;
384403 System .out .format ("Using %s:%n" , cs .index );
385404 // 1) Select benchmarks to run. Use .createDefault or .createEmpty (for other options)
386- List <QueryBenchmark > benchmarks = List .of (
387- ThroughputBenchmark .createEmpty (3 , 3 )
388- .displayAvgQps (),
389- LatencyBenchmark .createDefault (),
390- CountBenchmark .createDefault (),
391- AccuracyBenchmark .createDefault ()
392- );
405+
406+ var benchmarks = setupBenchmarks (benchmarkSpec );
393407 QueryTester tester = new QueryTester (benchmarks );
394408
395409 // 2) Setup benchmark table for printing
@@ -414,6 +428,85 @@ private static void testConfiguration(ConfiguredSystem cs,
414428 }
415429 }
416430
431+ private static List <QueryBenchmark > setupBenchmarks (Map <String , List <String >> benchmarkSpec ) {
432+ if (benchmarkSpec == null || benchmarkSpec .isEmpty ()) {
433+ return List .of (
434+ ThroughputBenchmark .createEmpty (3 , 3 )
435+ .displayAvgQps (),
436+ LatencyBenchmark .createDefault (),
437+ CountBenchmark .createDefault (),
438+ AccuracyBenchmark .createDefault ()
439+ );
440+ }
441+
442+ List <QueryBenchmark > benchmarks = new ArrayList <>();
443+
444+ for (var benchType : benchmarkSpec .keySet ()) {
445+ if (benchType .equals ("throughput" )) {
446+ var bench = ThroughputBenchmark .createEmpty (3 , 3 );
447+ for (var stat : benchmarkSpec .get (benchType )) {
448+ if (stat .equals ("AVG" )) {
449+ bench = bench .displayAvgQps ();
450+ }
451+ if (stat .equals ("MEDIAN" )) {
452+ bench = bench .displayMedianQps ();
453+ }
454+ if (stat .equals ("MAX" )) {
455+ bench = bench .displayMaxQps ();
456+ }
457+ }
458+ benchmarks .add (bench );
459+ }
460+
461+ if (benchType .equals ("latency" )) {
462+ var bench = LatencyBenchmark .createEmpty ();
463+ for (var stat : benchmarkSpec .get (benchType )) {
464+ if (stat .equals ("AVG" )) {
465+ bench = bench .displayAvgLatency ();
466+ }
467+ if (stat .equals ("STD" )) {
468+ bench = bench .displayLatencySTD ();
469+ }
470+ if (stat .equals ("P999" )) {
471+ bench = bench .displayP999Latency ();
472+ }
473+ }
474+ benchmarks .add (bench );
475+ }
476+
477+ if (benchType .equals ("count" )) {
478+ var bench = CountBenchmark .createEmpty ();
479+ for (var stat : benchmarkSpec .get (benchType )) {
480+ if (stat .equals ("visited" )) {
481+ bench = bench .displayAvgNodesVisited ();
482+ }
483+ if (stat .equals ("expanded" )) {
484+ bench = bench .displayAvgNodesExpanded ();
485+ }
486+ if (stat .equals ("expanded base layer" )) {
487+ bench = bench .displayAvgNodesExpandedBaseLayer ();
488+ }
489+ }
490+ benchmarks .add (bench );
491+ }
492+
493+ if (benchType .equals ("accuracy" )) {
494+ var bench = AccuracyBenchmark .createEmpty ();
495+ for (var stat : benchmarkSpec .get (benchType )) {
496+ if (stat .equals ("recall" )) {
497+ bench = bench .displayRecall ();
498+ }
499+ if (stat .equals ("MAP" )) {
500+ bench = bench .displayMAP ();
501+ }
502+ }
503+ benchmarks .add (bench );
504+ }
505+ }
506+
507+ return benchmarks ;
508+ }
509+
417510 private static VectorCompressor <?> getCompressor (Function <DataSet , CompressorParameters > cpSupplier , DataSet ds ) {
418511 var cp = cpSupplier .apply (ds );
419512 if (!cp .supportsCaching ()) {
0 commit comments