6161abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
6262
6363 static final TopDocs NO_RESULTS = TopDocsCollector .EMPTY_TOPDOCS ;
64+ public static final double VECTOR_VISITED_PERCENTAGE_BUDGET = 0.05 ;
6465
6566 protected final String field ;
6667 protected final int nProbe ;
@@ -131,15 +132,20 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
131132 KnnCollectorManager knnCollectorManager = getKnnCollectorManager (numCands , indexSearcher );
132133 TaskExecutor taskExecutor = indexSearcher .getTaskExecutor ();
133134 List <LeafReaderContext > leafReaderContexts = reader .leaves ();
135+
136+ int totalBudget = (int ) (reader .numDocs () * VECTOR_VISITED_PERCENTAGE_BUDGET );
137+
134138 List <Callable <TopDocs >> tasks ;
135139 if (leafReaderContexts .isEmpty () == false ) {
136-
137140 // calculate the affinity of each segment to the query vector
138141 // (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.)
139142 List <SegmentAffinity > segmentAffinities = calculateSegmentAffinities (leafReaderContexts , getQueryVector ());
140143
141144 // TODO: consider sorting segments by affinity score in descending order and cutting off the long tail
142- double [] affinityScores = segmentAffinities .stream ().map (SegmentAffinity ::affinityScore ).mapToDouble (Double ::doubleValue ).toArray ();
145+ double [] affinityScores = segmentAffinities .stream ()
146+ .map (SegmentAffinity ::affinityScore )
147+ .mapToDouble (Double ::doubleValue )
148+ .toArray ();
143149
144150 // max affinity for decreasing nProbe
145151 double averageAffinity = Arrays .stream (affinityScores ).average ().orElse (Double .NaN );
@@ -152,7 +158,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
152158 if (Double .isNaN (maxAffinity ) || Double .isNaN (averageAffinity )) {
153159 tasks = new ArrayList <>(leafReaderContexts .size ());
154160 for (LeafReaderContext context : leafReaderContexts ) {
155- tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , nProbe ));
161+ tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , nProbe , Integer . MAX_VALUE ));
156162 }
157163 } else {
158164 Map <LeafReaderContext , Integer > segmentNProbeMap = new HashMap <>();
@@ -173,8 +179,19 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
173179 }
174180
175181 tasks = new ArrayList <>(segmentNProbeMap .size ());
176- for (Map .Entry <LeafReaderContext , Integer > entry : segmentNProbeMap .entrySet ()) {
177- tasks .add (() -> searchLeaf (entry .getKey (), filterWeight , knnCollectorManager , entry .getValue ()));
182+ double scoreVectorsSum = segmentAffinities .stream ()
183+ .map (segmentAffinity -> segmentAffinity .affinityScore * segmentAffinity .context .reader ().numDocs ())
184+ .mapToDouble (Double ::doubleValue )
185+ .sum ();
186+
187+ for (SegmentAffinity segmentAffinity : segmentAffinities ) {
188+ double score = segmentAffinity .affinityScore ();
189+ int adjustedNProbe = adjustNProbeForSegment (score , affinityTreshold , maxAdjustments );
190+ LeafReaderContext context = segmentAffinity .context ();
191+
192+ // budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
193+ int segmentBudget = (int ) (totalBudget * (score * context .reader ().numDocs ()) / scoreVectorsSum );
194+ tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , adjustedNProbe , Math .max (1 , segmentBudget )));
178195 }
179196 }
180197 } else {
@@ -191,18 +208,17 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
191208 return new KnnScoreDocQuery (topK .scoreDocs , reader );
192209 }
193210
194- private int adjustNProbeForSegment (double affinityScore , double affinityTreshold , int maxAdjustment ) {
211+ private int adjustNProbeForSegment (double affinityScore , double affinityThreshold , int maxAdjustment ) {
195212 int baseNProbe = this .nProbe ;
196213
197214 // for high affinity scores, increase nProbe
198- if (affinityScore > affinityTreshold ) {
199- int adjustment = (int ) Math .ceil ((affinityScore - affinityTreshold ) * maxAdjustment );
215+ if (affinityScore > affinityThreshold ) {
216+ int adjustment = (int ) Math .ceil ((affinityScore - affinityThreshold ) * maxAdjustment );
200217 return Math .min (baseNProbe * adjustment , baseNProbe + maxAdjustment );
201218 }
202219
203220 // for low affinity scores, decrease nProbe
204- if (affinityScore <= affinityTreshold ) {
205- // int adjustment = (int) Math.ceil((affinityTreshold - affinityScore) * maxAdjustment);
221+ if (affinityScore <= affinityThreshold ) {
206222 return Math .max (baseNProbe / 3 , 1 );
207223 }
208224
@@ -288,9 +304,14 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
288304
289305 private record SegmentAffinity (LeafReaderContext context , double affinityScore ) {}
290306
291- private TopDocs searchLeaf (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager , int nProbe )
292- throws IOException {
293- TopDocs results = getLeafResults (ctx , filterWeight , knnCollectorManager , nProbe );
307+ private TopDocs searchLeaf (
308+ LeafReaderContext ctx ,
309+ Weight filterWeight ,
310+ KnnCollectorManager knnCollectorManager ,
311+ int nProbe ,
312+ int visitingBudget
313+ ) throws IOException {
314+ TopDocs results = getLeafResults (ctx , filterWeight , knnCollectorManager , nProbe , visitingBudget );
294315 if (ctx .docBase > 0 ) {
295316 for (ScoreDoc scoreDoc : results .scoreDocs ) {
296317 scoreDoc .doc += ctx .docBase ;
@@ -299,15 +320,20 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollec
299320 return results ;
300321 }
301322
302- TopDocs getLeafResults (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager , int nProbe )
303- throws IOException {
323+ TopDocs getLeafResults (
324+ LeafReaderContext ctx ,
325+ Weight filterWeight ,
326+ KnnCollectorManager knnCollectorManager ,
327+ int nProbe ,
328+ int visitingBudget
329+ ) throws IOException {
304330 final LeafReader reader = ctx .reader ();
305331 final Bits liveDocs = reader .getLiveDocs ();
306332
307333 KnnSearchStrategy searchStrategy = new IVFKnnSearchStrategy (nProbe );
308334
309335 if (filterWeight == null ) {
310- return approximateSearch (ctx , liveDocs , Integer . MAX_VALUE , knnCollectorManager , searchStrategy );
336+ return approximateSearch (ctx , liveDocs , visitingBudget , knnCollectorManager , searchStrategy );
311337 }
312338
313339 Scorer scorer = filterWeight .scorer (ctx );
@@ -317,7 +343,7 @@ TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorM
317343
318344 BitSet acceptDocs = createBitSet (scorer .iterator (), liveDocs , reader .maxDoc ());
319345 final int cost = acceptDocs .cardinality ();
320- return approximateSearch (ctx , acceptDocs , cost + 1 , knnCollectorManager , searchStrategy );
346+ return approximateSearch (ctx , acceptDocs , Math . min ( visitingBudget , cost + 1 ) , knnCollectorManager , searchStrategy );
321347 }
322348
323349 abstract TopDocs approximateSearch (
0 commit comments