3333import org .apache .lucene .queries .function .valuesource .FloatKnnVectorFieldSource ;
3434import org .apache .lucene .queries .function .valuesource .FloatVectorSimilarityFunction ;
3535import org .apache .lucene .search .IndexSearcher ;
36+ import org .apache .lucene .search .KnnByteVectorQuery ;
37+ import org .apache .lucene .search .KnnFloatVectorQuery ;
38+ import org .apache .lucene .search .PatienceKnnVectorQuery ;
3639import org .apache .lucene .search .Query ;
3740import org .apache .lucene .search .ScoreDoc ;
3841import org .apache .lucene .search .TopDocs ;
@@ -113,7 +116,7 @@ class KnnSearcher {
113116 this .searchThreads = cmdLineArgs .searchThreads ();
114117 }
115118
116- void runSearch (KnnIndexTester .Results finalResults ) throws IOException {
119+ void runSearch (KnnIndexTester .Results finalResults , boolean earlyTermination ) throws IOException {
117120 TopDocs [] results = new TopDocs [numQueryVectors ];
118121 int [][] resultIds = new int [numQueryVectors ][];
119122 long elapsed , totalCpuTimeMS , totalVisited = 0 ;
@@ -139,10 +142,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
139142 for (int i = 0 ; i < numQueryVectors ; i ++) {
140143 if (vectorEncoding .equals (VectorEncoding .BYTE )) {
141144 targetReader .next (targetBytes );
142- doVectorQuery (targetBytes , searcher );
145+ doVectorQuery (targetBytes , searcher , earlyTermination );
143146 } else {
144147 targetReader .next (target );
145- doVectorQuery (target , searcher );
148+ doVectorQuery (target , searcher , earlyTermination );
146149 }
147150 }
148151 targetReader .reset ();
@@ -151,10 +154,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
151154 for (int i = 0 ; i < numQueryVectors ; i ++) {
152155 if (vectorEncoding .equals (VectorEncoding .BYTE )) {
153156 targetReader .next (targetBytes );
154- results [i ] = doVectorQuery (targetBytes , searcher );
157+ results [i ] = doVectorQuery (targetBytes , searcher , earlyTermination );
155158 } else {
156159 targetReader .next (target );
157- results [i ] = doVectorQuery (target , searcher );
160+ results [i ] = doVectorQuery (target , searcher , earlyTermination );
158161 }
159162 }
160163 KnnIndexTester .ThreadDetails endThreadDetails = new KnnIndexTester .ThreadDetails ();
@@ -249,7 +252,7 @@ private boolean isNewer(Path path, Path... others) throws IOException {
249252 return true ;
250253 }
251254
252- TopDocs doVectorQuery (byte [] vector , IndexSearcher searcher ) throws IOException {
255+ TopDocs doVectorQuery (byte [] vector , IndexSearcher searcher , boolean earlyTermination ) throws IOException {
253256 Query knnQuery ;
254257 if (overSamplingFactor > 1f ) {
255258 throw new IllegalArgumentException ("oversampling factor > 1 is not supported for byte vectors" );
@@ -265,6 +268,9 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException
265268 null ,
266269 DenseVectorFieldMapper .FilterHeuristic .ACORN .getKnnSearchStrategy ()
267270 );
271+ if (indexType == KnnIndexTester .IndexType .HNSW && earlyTermination ) {
272+ knnQuery = PatienceKnnVectorQuery .fromByteQuery ((KnnByteVectorQuery ) knnQuery );
273+ }
268274 }
269275 QueryProfiler profiler = new QueryProfiler ();
270276 TopDocs docs = searcher .search (knnQuery , this .topK );
@@ -273,7 +279,7 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException
273279 return new TopDocs (new TotalHits (profiler .getVectorOpsCount (), docs .totalHits .relation ()), docs .scoreDocs );
274280 }
275281
276- TopDocs doVectorQuery (float [] vector , IndexSearcher searcher ) throws IOException {
282+ TopDocs doVectorQuery (float [] vector , IndexSearcher searcher , boolean earlyTermination ) throws IOException {
277283 Query knnQuery ;
278284 int topK = this .topK ;
279285 if (overSamplingFactor > 1f ) {
@@ -292,6 +298,9 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException
292298 null ,
293299 DenseVectorFieldMapper .FilterHeuristic .ACORN .getKnnSearchStrategy ()
294300 );
301+ if (indexType == KnnIndexTester .IndexType .HNSW && earlyTermination ) {
302+ knnQuery = PatienceKnnVectorQuery .fromFloatQuery ((KnnFloatVectorQuery ) knnQuery );
303+ }
295304 }
296305 if (overSamplingFactor > 1f ) {
297306 // oversample the topK results to get more candidates for the final result
0 commit comments