@@ -172,8 +172,15 @@ public static void main(String[] args) throws Exception {
172172 }
173173 }
174174 FormattedResults formattedResults = new FormattedResults ();
175+
175176 for (CmdLineArgs cmdLineArgs : cmdLineArgsList ) {
176- Results result = new Results (cmdLineArgs .indexType ().name ().toLowerCase (Locale .ROOT ), cmdLineArgs .numDocs ());
177+ int [] nProbes = cmdLineArgs .indexType ().equals (IndexType .IVF ) && cmdLineArgs .numQueries () > 0
178+ ? cmdLineArgs .nProbes ()
179+ : new int [] { 0 };
180+ Results [] results = new Results [nProbes .length ];
181+ for (int i = 0 ; i < nProbes .length ; i ++) {
182+ results [i ] = new Results (cmdLineArgs .indexType ().name ().toLowerCase (Locale .ROOT ), cmdLineArgs .numDocs ());
183+ }
177184 logger .info ("Running KNN index tester with arguments: " + cmdLineArgs );
178185 Codec codec = createCodec (cmdLineArgs );
179186 Path indexPath = PathUtils .get (formatIndexPath (cmdLineArgs ));
@@ -192,19 +199,22 @@ public static void main(String[] args) throws Exception {
192199 throw new IllegalArgumentException ("Index path does not exist: " + indexPath );
193200 }
194201 if (cmdLineArgs .reindex ()) {
195- knnIndexer .createIndex (result );
202+ knnIndexer .createIndex (results [ 0 ] );
196203 }
197204 if (cmdLineArgs .forceMerge ()) {
198- knnIndexer .forceMerge (result );
205+ knnIndexer .forceMerge (results [ 0 ] );
199206 } else {
200- knnIndexer .numSegments (result );
207+ knnIndexer .numSegments (results [ 0 ] );
201208 }
202209 }
203210 if (cmdLineArgs .queryVectors () != null && cmdLineArgs .numQueries () > 0 ) {
204- KnnSearcher knnSearcher = new KnnSearcher (indexPath , cmdLineArgs );
205- knnSearcher .runSearch (result );
211+ for (int i = 0 ; i < results .length ; i ++) {
212+ int nProbe = nProbes [i ];
213+ KnnSearcher knnSearcher = new KnnSearcher (indexPath , cmdLineArgs , nProbe );
214+ knnSearcher .runSearch (results [i ]);
215+ }
206216 }
207- formattedResults .results .add ( result );
217+ formattedResults .results .addAll ( List . of ( results ) );
208218 }
209219 logger .info ("Results: \n " + formattedResults );
210220 }
@@ -218,13 +228,12 @@ public String toString() {
218228 return "No results available." ;
219229 }
220230
231+ String [] indexingHeaders = { "index_type" , "num_docs" , "index_time(ms)" , "force_merge_time(ms)" , "num_segments" };
232+
221233 // Define column headers
222- String [] headers = {
234+ String [] searchHeaders = {
223235 "index_type" ,
224- "num_docs" ,
225- "index_time(ms)" ,
226- "force_merge_time(ms)" ,
227- "num_segments" ,
236+ "n_probe" ,
228237 "latency(ms)" ,
229238 "net_cpu_time(ms)" ,
230239 "avg_cpu_count" ,
@@ -233,41 +242,58 @@ public String toString() {
233242 "visited" };
234243
235244 // Calculate appropriate column widths based on headers and data
236- int [] widths = calculateColumnWidths (headers );
237245
238246 StringBuilder sb = new StringBuilder ();
239247
240- // Format and append header
241- sb .append (formatRow (headers , widths ));
242- sb .append ("\n " );
248+ Results indexResult = results .get (0 ); // Assuming all results have the same index type and numDocs
249+ String [] indexData = {
250+ indexResult .indexType ,
251+ Integer .toString (indexResult .numDocs ),
252+ Long .toString (indexResult .indexTimeMS ),
253+ Long .toString (indexResult .forceMergeTimeMS ),
254+ Integer .toString (indexResult .numSegments ) };
243255
244- // Add separator line
245- for (int width : widths ) {
246- sb .append ("-" .repeat (width )).append (" " );
247- }
248- sb .append ("\n " );
256+ printBlock (sb , indexingHeaders , new String [][] { indexData });
249257
258+ String [][] searchData = new String [results .size ()][];
250259 // Format and append each row of data
251- for (Results result : results ) {
252- String [] rowData = {
260+ for (int i = 0 ; i < results .size (); i ++) {
261+ Results result = results .get (i );
262+ searchData [i ] = new String [] {
253263 result .indexType ,
254- Integer .toString (result .numDocs ),
255- Long .toString (result .indexTimeMS ),
256- Long .toString (result .forceMergeTimeMS ),
257- Integer .toString (result .numSegments ),
264+ Integer .toString (result .nProbe ),
258265 String .format (Locale .ROOT , "%.2f" , result .avgLatency ),
259266 String .format (Locale .ROOT , "%.2f" , result .netCpuTimeMS ),
260267 String .format (Locale .ROOT , "%.2f" , result .avgCpuCount ),
261268 String .format (Locale .ROOT , "%.2f" , result .qps ),
262269 String .format (Locale .ROOT , "%.2f" , result .avgRecall ),
263270 String .format (Locale .ROOT , "%.2f" , result .averageVisited ) };
264- sb .append (formatRow (rowData , widths ));
265- sb .append ("\n " );
271+
266272 }
267273
274+ printBlock (sb , searchHeaders , searchData );
275+
268276 return sb .toString ();
269277 }
270278
279+ private void printBlock (StringBuilder sb , String [] headers , String [][] rows ) {
280+ int [] widths = calculateColumnWidths (headers , rows );
281+ sb .append ("\n " );
282+ sb .append (formatRow (headers , widths ));
283+ sb .append ("\n " );
284+
285+ // Add separator line
286+ for (int width : widths ) {
287+ sb .append ("-" .repeat (width )).append (" " );
288+ }
289+ sb .append ("\n " );
290+
291+ for (String [] row : rows ) {
292+ sb .append (formatRow (row , widths ));
293+ sb .append ("\n " );
294+ }
295+ }
296+
271297 // Helper method to format a single row with proper column widths
272298 private String formatRow (String [] values , int [] widths ) {
273299 StringBuilder row = new StringBuilder ();
@@ -285,7 +311,7 @@ private String formatRow(String[] values, int[] widths) {
285311 }
286312
287313 // Calculate appropriate column widths based on headers and data
288- private int [] calculateColumnWidths (String [] headers ) {
314+ private int [] calculateColumnWidths (String [] headers , String []... data ) {
289315 int [] widths = new int [headers .length ];
290316
291317 // Initialize widths with header lengths
@@ -294,20 +320,7 @@ private int[] calculateColumnWidths(String[] headers) {
294320 }
295321
296322 // Update widths based on data
297- for (Results result : results ) {
298- String [] values = {
299- result .indexType ,
300- Integer .toString (result .numDocs ),
301- Long .toString (result .indexTimeMS ),
302- Long .toString (result .forceMergeTimeMS ),
303- Integer .toString (result .numSegments ),
304- String .format (Locale .ROOT , "%.2f" , result .avgLatency ),
305- String .format (Locale .ROOT , "%.2f" , result .netCpuTimeMS ),
306- String .format (Locale .ROOT , "%.2f" , result .avgCpuCount ),
307- String .format (Locale .ROOT , "%.2f" , result .qps ),
308- String .format (Locale .ROOT , "%.2f" , result .avgRecall ),
309- String .format (Locale .ROOT , "%.2f" , result .averageVisited ) };
310-
323+ for (String [] values : data ) {
311324 for (int i = 0 ; i < values .length ; i ++) {
312325 widths [i ] = Math .max (widths [i ], values [i ].length ());
313326 }
@@ -323,6 +336,7 @@ static class Results {
323336 long indexTimeMS ;
324337 long forceMergeTimeMS ;
325338 int numSegments ;
339+ int nProbe ;
326340 double avgLatency ;
327341 double qps ;
328342 double avgRecall ;
0 commit comments