1111
1212import com .carrotsearch .hppc .IntHashSet ;
1313
14+ import org .apache .lucene .index .FloatVectorValues ;
1415import org .apache .lucene .index .IndexReader ;
1516import org .apache .lucene .index .LeafReader ;
1617import org .apache .lucene .index .LeafReaderContext ;
@@ -50,11 +51,10 @@ abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerP
5051 static final TopDocs NO_RESULTS = TopDocsCollector .EMPTY_TOPDOCS ;
5152
5253 protected final String field ;
53- protected final float visitRatio ;
54+ protected final float providedVisitRatio ;
5455 protected final int k ;
5556 protected final int numCands ;
5657 protected final Query filter ;
57- protected final IVFKnnSearchStrategy searchStrategy ;
5858 protected int vectorOpsCount ;
5959
6060 protected AbstractIVFKnnVectorQuery (String field , float visitRatio , int k , int numCands , Query filter ) {
@@ -68,11 +68,10 @@ protected AbstractIVFKnnVectorQuery(String field, float visitRatio, int k, int n
6868 throw new IllegalArgumentException ("numCands must be at least k, got: " + numCands );
6969 }
7070 this .field = field ;
71- this .visitRatio = visitRatio ;
71+ this .providedVisitRatio = visitRatio ;
7272 this .k = k ;
7373 this .filter = filter ;
7474 this .numCands = numCands ;
75- this .searchStrategy = new IVFKnnSearchStrategy (visitRatio );
7675 }
7776
7877 @ Override
@@ -90,12 +89,12 @@ public boolean equals(Object o) {
9089 return k == that .k
9190 && Objects .equals (field , that .field )
9291 && Objects .equals (filter , that .filter )
93- && Objects .equals (visitRatio , that .visitRatio );
92+ && Objects .equals (providedVisitRatio , that .providedVisitRatio );
9493 }
9594
// Hash over field, k, filter and the visit ratio as it was passed to the constructor;
// these are the same four values that equals() compares. numCands is not part of the hash.
9695 @ Override
9796 public int hashCode () {
98- return Objects .hash (field , k , filter , visitRatio );
97+ return Objects .hash (field , k , filter , providedVisitRatio );
9998 }
10099
101100 @ Override
@@ -116,16 +115,36 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
116115 } else {
117116 filterWeight = null ;
118117 }
118+
119119 // We request numCands results because numCands is used as the approximation measure,
120120 // but never fewer than 2*k, so that overspill duplicates are covered.
121121 // TODO move the logic for automatically adjusting percentages to the query, so that we only need to pass
122122 // 2k to the collector.
123123 KnnCollectorManager knnCollectorManager = getKnnCollectorManager (Math .max (Math .round (2f * k ), numCands ), indexSearcher );
124124 TaskExecutor taskExecutor = indexSearcher .getTaskExecutor ();
125125 List <LeafReaderContext > leafReaderContexts = reader .leaves ();
126+
127+ int totalVectors = 0 ;
128+ for (LeafReaderContext leafReaderContext : leafReaderContexts ) {
129+ LeafReader leafReader = leafReaderContext .reader ();
130+ FloatVectorValues floatVectorValues = leafReader .getFloatVectorValues (field );
131+ if (floatVectorValues != null ) {
132+ totalVectors += floatVectorValues .size ();
133+ }
134+ }
135+
136+ final float visitRatio ;
137+ if (providedVisitRatio == 0.0f ) {
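138+ // a provided ratio of 0 means "choose automatically": derive the ratio from numCands and the total vector count (NOTE: totalVectors can be 0 here)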
139+ float expected = (float ) Math .round (1.75f * Math .log10 (numCands ) * Math .log10 (numCands ) * (numCands ));
140+ visitRatio = expected / totalVectors ;
141+ } else {
142+ visitRatio = providedVisitRatio ;
143+ }
144+
126145 List <Callable <TopDocs >> tasks = new ArrayList <>(leafReaderContexts .size ());
127146 for (LeafReaderContext context : leafReaderContexts ) {
128- tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager ));
147+ tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , visitRatio ));
129148 }
130149 TopDocs [] perLeafResults = taskExecutor .invokeAll (tasks ).toArray (TopDocs []::new );
131150
@@ -138,8 +157,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
138157 return new KnnScoreDocQuery (topK .scoreDocs , reader );
139158 }
140159
141- private TopDocs searchLeaf (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager ) throws IOException {
142- TopDocs results = getLeafResults (ctx , filterWeight , knnCollectorManager );
160+ private TopDocs searchLeaf (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager , float visitRatio )
161+ throws IOException {
162+ TopDocs results = getLeafResults (ctx , filterWeight , knnCollectorManager , visitRatio );
143163 IntHashSet dedup = new IntHashSet (results .scoreDocs .length * 4 / 3 );
144164 int deduplicateCount = 0 ;
145165 for (ScoreDoc scoreDoc : results .scoreDocs ) {
@@ -159,12 +179,13 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollec
159179 return new TopDocs (results .totalHits , deduplicatedScoreDocs );
160180 }
161181
162- TopDocs getLeafResults (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager ) throws IOException {
182+ TopDocs getLeafResults (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager , float visitRatio )
183+ throws IOException {
163184 final LeafReader reader = ctx .reader ();
164185 final Bits liveDocs = reader .getLiveDocs ();
165186
166187 if (filterWeight == null ) {
167- return approximateSearch (ctx , liveDocs , Integer .MAX_VALUE , knnCollectorManager );
188+ return approximateSearch (ctx , liveDocs , Integer .MAX_VALUE , knnCollectorManager , visitRatio );
168189 }
169190
170191 Scorer scorer = filterWeight .scorer (ctx );
@@ -174,14 +195,15 @@ TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorM
174195
175196 BitSet acceptDocs = createBitSet (scorer .iterator (), liveDocs , reader .maxDoc ());
176197 final int cost = acceptDocs .cardinality ();
177- return approximateSearch (ctx , acceptDocs , cost + 1 , knnCollectorManager );
198+ return approximateSearch (ctx , acceptDocs , cost + 1 , knnCollectorManager , visitRatio );
178199 }
179200
/**
 * Runs the approximate search for a single leaf; implemented by subclasses.
 * Called from getLeafResults once per leaf.
 *
 * @param context             the leaf to search
 * @param acceptDocs          documents that may be returned: the leaf's live docs when the query has
 *                            no filter, otherwise the bit set built from the filter's matches.
 *                            NOTE(review): the unfiltered value comes straight from
 *                            reader.getLiveDocs() - confirm whether implementations must handle null.
 * @param visitedLimit        Integer.MAX_VALUE when there is no filter, otherwise the cardinality of
 *                            the filter bit set plus one
 * @param knnCollectorManager collector manager created in rewrite and passed down per leaf
 * @param visitRatio          ratio resolved in rewrite: the provided ratio, or, when that is 0, a value
 *                            derived from numCands and the total number of vectors in the index
 * @return the top documents found in this leaf (searchLeaf de-duplicates them afterwards)
 * @throws IOException declared for implementations that read from the index
 */
180201 abstract TopDocs approximateSearch (
181202 LeafReaderContext context ,
182203 Bits acceptDocs ,
183204 int visitedLimit ,
184- KnnCollectorManager knnCollectorManager
205+ KnnCollectorManager knnCollectorManager ,
206+ float visitRatio
185207 ) throws IOException ;
186208
187209 protected KnnCollectorManager getKnnCollectorManager (int k , IndexSearcher searcher ) {
0 commit comments