5555
5656import java .io .IOException ;
5757import java .io .OutputStream ;
58+ import java .io .UncheckedIOException ;
5859import java .nio .ByteBuffer ;
5960import java .nio .ByteOrder ;
6061import java .nio .IntBuffer ;
7172import java .util .concurrent .ExecutorService ;
7273import java .util .concurrent .Executors ;
7374import java .util .concurrent .ForkJoinPool ;
75+ import java .util .concurrent .Future ;
7476import java .util .concurrent .TimeUnit ;
77+ import java .util .function .IntConsumer ;
7578
7679import static org .apache .lucene .search .DocIdSetIterator .NO_MORE_DOCS ;
7780import static org .elasticsearch .test .knn .KnnIndexTester .logger ;
@@ -96,6 +99,7 @@ class KnnSearcher {
9699 private final VectorEncoding vectorEncoding ;
97100 private final float overSamplingFactor ;
98101 private final int searchThreads ;
102+ private final int numSearchers ;
99103
100104 KnnSearcher (Path indexPath , CmdLineArgs cmdLineArgs , int nProbe ) {
101105 this .docPath = cmdLineArgs .docVectors ();
@@ -115,6 +119,7 @@ class KnnSearcher {
115119 this .nProbe = nProbe ;
116120 this .indexType = cmdLineArgs .indexType ();
117121 this .searchThreads = cmdLineArgs .searchThreads ();
122+ this .numSearchers = cmdLineArgs .numSearchers ();
118123 }
119124
120125 void runSearch (KnnIndexTester .Results finalResults , boolean earlyTermination ) throws IOException {
@@ -124,7 +129,10 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
124129 int offsetByteSize = 0 ;
125130 try (
126131 FileChannel input = FileChannel .open (queryPath );
127- ExecutorService executorService = Executors .newFixedThreadPool (searchThreads , r -> new Thread (r , "KnnSearcher-Thread" ))
132+ ExecutorService executorService = Executors .newFixedThreadPool (searchThreads , r -> new Thread (r , "KnnSearcher-Thread" ));
133+ ExecutorService numSearchersExecutor = numSearchers > 1
134+ ? Executors .newFixedThreadPool (numSearchers , r -> new Thread (r , "KnnSearcher-Caller" ))
135+ : null
128136 ) {
129137 long queryPathSizeInBytes = input .size ();
130138 logger .info (
@@ -163,29 +171,87 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
163171 }
164172 }
165173 targetReader .reset ();
174+ final IntConsumer [] queryConsumers = new IntConsumer [numSearchers ];
175+ if (vectorEncoding .equals (VectorEncoding .BYTE )) {
176+ byte [][] queries = new byte [numQueryVectors ][dim ];
177+ for (int i = 0 ; i < numQueryVectors ; i ++) {
178+ targetReader .next (queries [i ]);
179+ }
180+ for (int s = 0 ; s < numSearchers ; s ++) {
181+ queryConsumers [s ] = i -> {
182+ try {
183+ results [i ] = doVectorQuery (queries [i ], searcher , earlyTermination );
184+ } catch (IOException e ) {
185+ throw new UncheckedIOException (e );
186+ }
187+ };
188+ }
189+ } else {
190+ float [][] queries = new float [numQueryVectors ][dim ];
191+ for (int i = 0 ; i < numQueryVectors ; i ++) {
192+ targetReader .next (queries [i ]);
193+ }
194+ for (int s = 0 ; s < numSearchers ; s ++) {
195+ queryConsumers [s ] = i -> {
196+ try {
197+ results [i ] = doVectorQuery (queries [i ], searcher , earlyTermination );
198+ } catch (IOException e ) {
199+ throw new UncheckedIOException (e );
200+ }
201+ };
202+ }
203+ }
204+ int [][] querySplits = new int [numSearchers ][];
205+ int queriesPerSearcher = numQueryVectors / numSearchers ;
206+ for (int s = 0 ; s < numSearchers ; s ++) {
207+ int start = s * queriesPerSearcher ;
208+ int end = (s == numSearchers - 1 ) ? numQueryVectors : (s + 1 ) * queriesPerSearcher ;
209+ querySplits [s ] = new int [end - start ];
210+ for (int i = start ; i < end ; i ++) {
211+ querySplits [s ][i - start ] = i ;
212+ }
213+ }
214+ targetReader .reset ();
166215 startNS = System .nanoTime ();
167216 KnnIndexTester .ThreadDetails startThreadDetails = new KnnIndexTester .ThreadDetails ();
168- for (int i = 0 ; i < numQueryVectors ; i ++) {
169- if (vectorEncoding .equals (VectorEncoding .BYTE )) {
170- targetReader .next (targetBytes );
171- results [i ] = doVectorQuery (targetBytes , searcher , earlyTermination );
172- } else {
173- targetReader .next (target );
174- results [i ] = doVectorQuery (target , searcher , earlyTermination );
217+ if (numSearchersExecutor != null ) {
218+ // use multiple searchers
219+ var futures = new ArrayList <Future <Void >>();
220+ for (int s = 0 ; s < numSearchers ; s ++) {
221+ int [] split = querySplits [s ];
222+ IntConsumer queryConsumer = queryConsumers [s ];
223+ futures .add (numSearchersExecutor .submit (() -> {
224+ for (int j : split ) {
225+ queryConsumer .accept (j );
226+ }
227+ return null ;
228+ }));
229+ }
230+ for (Future <Void > future : futures ) {
231+ try {
232+ future .get ();
233+ } catch (Exception e ) {
234+ throw new RuntimeException ("Error executing searcher thread" , e );
235+ }
236+ }
237+ } else {
238+ // use a single searcher
239+ for (int i = 0 ; i < numQueryVectors ; i ++) {
240+ queryConsumers [0 ].accept (i );
175241 }
176242 }
177243 KnnIndexTester .ThreadDetails endThreadDetails = new KnnIndexTester .ThreadDetails ();
178244 elapsed = TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - startNS );
179245 long startCPUTimeNS = 0 ;
180246 long endCPUTimeNS = 0 ;
181247 for (int i = 0 ; i < startThreadDetails .threadInfos .length ; i ++) {
182- if (startThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher-Thread " )) {
248+ if (startThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher" )) {
183249 startCPUTimeNS += startThreadDetails .cpuTimesNS [i ];
184250 }
185251 }
186252
187253 for (int i = 0 ; i < endThreadDetails .threadInfos .length ; i ++) {
188- if (endThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher-Thread " )) {
254+ if (endThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher" )) {
189255 endCPUTimeNS += endThreadDetails .cpuTimesNS [i ];
190256 }
191257 }
0 commit comments