@@ -62,6 +62,9 @@ abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerP
6262
6363 static final TopDocs NO_RESULTS = TopDocsCollector .EMPTY_TOPDOCS ;
6464
65+ private static final float MAX_VISIT_INCREASE_RATIO = 0.1f ; // added to the normalized affinity when scaling visitRatio; per-segment ratio is capped at visitRatio * (1 + this)
66+ private static final float MAX_VISIT_DECREASE_RATIO = 0.01f ; // per-segment ratio is floored at visitRatio * this; float (not double) so the 0.01f literal is not widened to 0.00999999977...
67+
6568 protected final String field ;
6669 protected final float providedVisitRatio ;
6770 protected final int k ;
@@ -166,79 +169,48 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
166169 visitRatio = providedVisitRatio ;
167170 }
168171
169- // FIXME: pick a reasonable min budget and make sure visitRatio is used appropriately throughout and not pushing the budget to zero
170- int totalBudget = Math .max (100 , (int ) (totalDocsWVectors * visitRatio ));
171-
172172 List <Callable <TopDocs >> tasks ;
173173 if (leafReaderContexts .isEmpty () == false ) {
174174 // calculate the affinity of each segment to the query vector
175- // (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.)
176175 List <SegmentAffinity > segmentAffinities = calculateSegmentAffinities (leafReaderContexts , getQueryVector (), costs );
176+ segmentAffinities .sort ((a , b ) -> Double .compare (b .affinityScore (), a .affinityScore ()));
177177
178- // TODO: sort segments by affinity score in descending order, and cut the long tail ?
179178 double [] affinityScores = segmentAffinities .stream ()
180179 .map (SegmentAffinity ::affinityScore )
181180 .mapToDouble (Double ::doubleValue )
182- .toArray ();
183-
184- double [] filteredAffinityScores = Arrays .stream (affinityScores )
185181 .filter (x -> Double .isNaN (x ) == false && Double .isInfinite (x ) == false )
186182 .toArray ();
187183
188- final double averageAffinity = Arrays .stream (filteredAffinityScores ).average ().orElse (0.0 );
189- double variance = Arrays .stream (filteredAffinityScores ).map (x -> (x - averageAffinity ) * (x - averageAffinity )).sum ()
190- / filteredAffinityScores .length ;
191- double stdDev = Math .sqrt (variance );
192-
193- double maxAffinity ;
194- double cutoffAffinity ;
195- double affinityThreshold ;
196-
197- if (stdDev > averageAffinity ) {
198- // adjust calculation when distribution is very skewed
199- double minAffinity = Arrays .stream (affinityScores ).min ().orElse (Double .NaN );
200- maxAffinity = Arrays .stream (affinityScores ).max ().orElse (Double .NaN );
201- double lowerAffinity = (minAffinity + averageAffinity ) * 0.5 ;
202- cutoffAffinity = lowerAffinity * 0.1 ;
203- affinityThreshold = (minAffinity + lowerAffinity ) * 0.66 ;
204- } else {
205- maxAffinity = averageAffinity + 50 * stdDev ;
206- cutoffAffinity = averageAffinity - 50 * stdDev ;
207- affinityThreshold = averageAffinity + stdDev ;
208- }
184+ double minAffinity = Arrays .stream (affinityScores ).min ().orElse (Double .NaN );
185+ double maxAffinity = Arrays .stream (affinityScores ).max ().orElse (Double .NaN );
209186
210- float maxAdjustments = visitRatio * 1.5f ;
187+ double [] normalizedAffinityScores = Arrays .stream (affinityScores )
188+ .map (d -> (d - minAffinity ) / (maxAffinity - minAffinity ))
189+ .toArray ();
211190
212- if (Double . isNaN ( maxAffinity ) || Double . isNaN ( averageAffinity )) {
191+ if (normalizedAffinityScores . length != segmentAffinities . size ( )) {
213192 tasks = new ArrayList <>(leafReaderContexts .size ());
214193 for (LeafReaderContext context : leafReaderContexts ) {
215- tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , visitRatio , Integer . MAX_VALUE ));
194+ tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , visitRatio ));
216195 }
217196 } else {
218197 tasks = new ArrayList <>(segmentAffinities .size ());
219- double scoreVectorsSum = segmentAffinities .stream ().map (segmentAffinity -> {
220- double affinity = Double .isNaN (segmentAffinity .affinityScore ) ? maxAffinity : segmentAffinity .affinityScore ;
221- affinity = Math .clamp (affinity , 0.0 , maxAffinity );
222- return affinity * segmentAffinity .numVectors ;
223- }).mapToDouble (Double ::doubleValue ).sum ();
224-
198+ int j = 0 ;
225199 for (SegmentAffinity segmentAffinity : segmentAffinities ) {
226- double score = segmentAffinity .affinityScore ();
227- if (score < cutoffAffinity ) {
228- continue ;
229- }
230- float adjustedVisitRatio = adjustVisitRatioForSegment (score , affinityThreshold , maxAdjustments , visitRatio );
231- LeafReaderContext context = segmentAffinity .context ();
200+ double normalizedAffinityScore = normalizedAffinityScores [j ];
232201
233- // distribute the budget according to : budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
234- int segmentBudget = (int ) (totalBudget * (score * segmentAffinity .numVectors ) / scoreVectorsSum );
202+ float adjustedVisitRatio = (float ) Math .clamp (
203+ (visitRatio * (MAX_VISIT_INCREASE_RATIO + normalizedAffinityScore )),
204+ visitRatio * MAX_VISIT_DECREASE_RATIO ,
205+ visitRatio * (1 + MAX_VISIT_INCREASE_RATIO )
206+ );
235207
236- // TODO : should we always grant a min budget for each affine-enough segment
237- if (segmentBudget > 0 ) {
238- tasks .add (
239- () -> searchLeaf (context , filterWeight , knnCollectorManager , adjustedVisitRatio , Math .max (1 , segmentBudget ))
240- );
208+ LeafReaderContext context = segmentAffinity .context ();
209+
210+ if (adjustedVisitRatio > 0 ) {
211+ tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , adjustedVisitRatio ));
241212 }
213+ j ++;
242214 }
243215 }
244216 } else {
@@ -257,21 +229,6 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
257229
258230 abstract VectorScorer createVectorScorer (LeafReaderContext context , FieldInfo fi ) throws IOException ;
259231
260- private float adjustVisitRatioForSegment (double affinityScore , double affinityThreshold , float maxAdjustment , float visitRatio ) {
261- // for high affinity scores, increase visited ratio
262- if (affinityScore > affinityThreshold ) {
263- int adjustment = (int ) Math .ceil ((affinityScore - affinityThreshold ) * maxAdjustment );
264- return Math .min (visitRatio * adjustment , visitRatio + maxAdjustment );
265- }
266-
267- // for low affinity scores, decrease visited ratio
268- if (affinityScore <= affinityThreshold ) {
269- return Math .max (visitRatio * 0.5f , 0.01f );
270- }
271-
272- return visitRatio ;
273- }
274-
275232 abstract float [] getQueryVector () throws IOException ;
276233
277234 private IVFVectorsReader unwrapReader (KnnVectorsReader knnVectorsReader ) {
@@ -287,8 +244,7 @@ private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
287244 return result ;
288245 }
289246
290- private List <SegmentAffinity > calculateSegmentAffinities (List <LeafReaderContext > leafReaderContexts , float [] queryVector , int [] costs )
291- throws IOException {
247+ private List <SegmentAffinity > calculateSegmentAffinities (List <LeafReaderContext > leafReaderContexts , float [] queryVector , int [] costs ) {
292248 List <SegmentAffinity > segmentAffinities = new ArrayList <>(leafReaderContexts .size ());
293249
294250 int i = 0 ;
@@ -325,25 +281,8 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
325281 int numCentroids = reader .getNumCentroids (fieldInfo );
326282 double centroidDensity = (double ) numCentroids / numVectors ;
327283
328- // with larger clusters, global centroid might not be a good representative,
329- // so we want to include "some" centroids' scores for higher quality estimate
330- // TODO: tweak the threshold numCentroids here
331- if (numCentroids > 64 ) {
332- float [] centroidScores = reader .getCentroidsScores (
333- fieldInfo ,
334- numCentroids ,
335- reader .getIvfCentroids (fieldInfo ),
336- queryVector ,
337- numCentroids > 128
338- );
339- Arrays .sort (centroidScores );
340- float first = centroidScores [centroidScores .length - 1 ];
341- float second = centroidScores [centroidScores .length - 2 ];
342- centroidsScore = (centroidsScore + first + second ) / 3 ;
343- }
344-
284+ // TODO : we may want to include some actual centroids' scores for higher quality estimate
345285 double affinityScore = centroidsScore * (1 + centroidDensity );
346-
347286 segmentAffinities .add (new SegmentAffinity (context , affinityScore , numVectors ));
348287 } else {
349288 segmentAffinities .add (new SegmentAffinity (context , Float .NaN , 0 ));
@@ -357,14 +296,9 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
357296
358297 private record SegmentAffinity (LeafReaderContext context , double affinityScore , int numVectors ) {} // one entry per leaf built by calculateSegmentAffinities: the leaf context, its affinity score to the query, and its vector count; the fallback branch there stores NaN / 0
359298
360- private TopDocs searchLeaf (
361- LeafReaderContext ctx ,
362- Weight filterWeight ,
363- KnnCollectorManager knnCollectorManager ,
364- float visitRatio ,
365- int visitingBudget
366- ) throws IOException {
367- TopDocs results = getLeafResults (ctx , filterWeight , knnCollectorManager , visitRatio , visitingBudget );
299+ private TopDocs searchLeaf (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager , float visitRatio )
300+ throws IOException {
301+ TopDocs results = getLeafResults (ctx , filterWeight , knnCollectorManager , visitRatio );
368302 IntHashSet dedup = new IntHashSet (results .scoreDocs .length * 4 / 3 );
369303 int deduplicateCount = 0 ;
370304 for (ScoreDoc scoreDoc : results .scoreDocs ) {
@@ -384,19 +318,13 @@ private TopDocs searchLeaf(
384318 return new TopDocs (results .totalHits , deduplicatedScoreDocs );
385319 }
386320
387- TopDocs getLeafResults (
388- LeafReaderContext ctx ,
389- Weight filterWeight ,
390- KnnCollectorManager knnCollectorManager ,
391- float visitRatio ,
392- int visitingBudget
393- ) throws IOException {
321+ TopDocs getLeafResults (LeafReaderContext ctx , Weight filterWeight , KnnCollectorManager knnCollectorManager , float visitRatio )
322+ throws IOException {
394323 final LeafReader reader = ctx .reader ();
395324 final Bits liveDocs = reader .getLiveDocs ();
396325
397- KnnSearchStrategy searchStrategy = new IVFKnnSearchStrategy (visitRatio );
398326 if (filterWeight == null ) {
399- return approximateSearch (ctx , liveDocs , visitingBudget , knnCollectorManager , searchStrategy );
327+ return approximateSearch (ctx , liveDocs , Integer . MAX_VALUE , knnCollectorManager , visitRatio );
400328 }
401329
402330 Scorer scorer = filterWeight .scorer (ctx );
@@ -406,15 +334,15 @@ TopDocs getLeafResults(
406334
407335 BitSet acceptDocs = createBitSet (scorer .iterator (), liveDocs , reader .maxDoc ());
408336 final int cost = acceptDocs .cardinality ();
409- return approximateSearch (ctx , acceptDocs , Math . min ( visitingBudget , cost + 1 ) , knnCollectorManager , searchStrategy );
337+ return approximateSearch (ctx , acceptDocs , cost + 1 , knnCollectorManager , visitRatio );
410338 }
411339
412340 abstract TopDocs approximateSearch ( // abstract per-leaf search; getLeafResults delegates to it
413341 LeafReaderContext context ,
414342 Bits acceptDocs , // getLeafResults passes the reader's live docs when there is no filter, otherwise the bit set built from the filter
415343 int visitedLimit , // getLeafResults passes Integer.MAX_VALUE when there is no filter, otherwise the filter bit set's cardinality + 1
416344 KnnCollectorManager knnCollectorManager ,
417- KnnSearchStrategy searchStrategy
345+ float visitRatio // ratio forwarded from searchLeaf (the per-segment adjusted value on the affinity path); NOTE(review): how implementations apply it is not visible here
418346 ) throws IOException ;
419347
420348 protected KnnCollectorManager getKnnCollectorManager (int k , IndexSearcher searcher ) {
0 commit comments