@@ -138,13 +138,21 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
138138
139139 int totalDocsWVectors = 0 ;
140140 assert this instanceof IVFKnnFloatVectorQuery ;
141+ int [] costs = new int [leafReaderContexts .size ()];
142+ int i = 0 ;
141143 for (LeafReaderContext leafReaderContext : leafReaderContexts ) {
142144 LeafReader leafReader = leafReaderContext .reader ();
143145 FieldInfo fieldInfo = leafReader .getFieldInfos ().fieldInfo (field );
144146 VectorScorer scorer = createVectorScorer (leafReaderContext , fieldInfo );
147+ int cost ;
145148 if (scorer != null ) {
146- totalDocsWVectors += (int ) scorer .iterator ().cost ();
149+ cost = (int ) scorer .iterator ().cost ();
150+ totalDocsWVectors += cost ;
151+ } else {
152+ cost = 0 ;
147153 }
154+ costs [i ] = cost ;
155+ i ++;
148156 }
149157
150158 final float visitRatio ;
@@ -165,7 +173,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
165173 if (leafReaderContexts .isEmpty () == false ) {
166174 // calculate the affinity of each segment to the query vector
167175 // (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.)
168- List <SegmentAffinity > segmentAffinities = calculateSegmentAffinities (leafReaderContexts , getQueryVector ());
176+ List <SegmentAffinity > segmentAffinities = calculateSegmentAffinities (leafReaderContexts , getQueryVector (), costs );
169177
170178 // TODO: sort segments by affinity score in descending order, and cut the long tail ?
171179 double [] affinityScores = segmentAffinities .stream ()
@@ -182,9 +190,22 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
182190 / filteredAffinityScores .length ;
183191 double stdDev = Math .sqrt (variance );
184192
185- double maxAffinity = averageAffinity + 50 * stdDev ;
186- double cutoffAffinity = averageAffinity - 50 * stdDev ;
187- double affinityThreshold = averageAffinity + stdDev ;
193+ double maxAffinity ;
194+ double cutoffAffinity ;
195+ double affinityThreshold ;
196+
197+ if (stdDev > averageAffinity ) {
198+ // skewed distribution (stdDev exceeds the mean): derive the bounds from the observed min/max and the mean instead of mean ± k * stdDev
199+ double minAffinity = Arrays .stream (affinityScores ).min ().orElse (Double .NaN );
200+ maxAffinity = Arrays .stream (affinityScores ).max ().orElse (Double .NaN );
201+ double lowerAffinity = (minAffinity + averageAffinity ) * 0.5 ;
202+ cutoffAffinity = lowerAffinity * 0.1 ;
203+ affinityThreshold = (minAffinity + lowerAffinity ) * 0.66 ;
204+ } else {
205+ maxAffinity = averageAffinity + 50 * stdDev ;
206+ cutoffAffinity = averageAffinity - 50 * stdDev ;
207+ affinityThreshold = averageAffinity + stdDev ;
208+ }
188209
189210 float maxAdjustments = visitRatio * 1.5f ;
190211
@@ -198,7 +219,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
198219 double scoreVectorsSum = segmentAffinities .stream ().map (segmentAffinity -> {
199220 double affinity = Double .isNaN (segmentAffinity .affinityScore ) ? maxAffinity : segmentAffinity .affinityScore ;
200221 affinity = Math .clamp (affinity , 0.0 , maxAffinity );
201- return affinity * segmentAffinity .context . reader (). numDocs () ;
222+ return affinity * segmentAffinity .numVectors ;
202223 }).mapToDouble (Double ::doubleValue ).sum ();
203224
204225 for (SegmentAffinity segmentAffinity : segmentAffinities ) {
@@ -210,8 +231,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
210231 LeafReaderContext context = segmentAffinity .context ();
211232
212233 // distribute the budget according to : budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
213- int segmentBudget = (int ) (totalBudget * (score * context .reader ().numDocs ()) / scoreVectorsSum );
214- tasks .add (() -> searchLeaf (context , filterWeight , knnCollectorManager , adjustedVisitRatio , Math .max (1 , segmentBudget )));
234+ int segmentBudget = (int ) (totalBudget * (score * segmentAffinity .numVectors ) / scoreVectorsSum );
235+ if (segmentBudget > 0 ) {
236+ tasks .add (
237+ () -> searchLeaf (context , filterWeight , knnCollectorManager , adjustedVisitRatio , Math .max (1 , segmentBudget ))
238+ );
239+ }
215240 }
216241 }
217242 } else {
@@ -260,10 +285,11 @@ private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
260285 return result ;
261286 }
262287
263- private List <SegmentAffinity > calculateSegmentAffinities (List <LeafReaderContext > leafReaderContexts , float [] queryVector )
288+ private List <SegmentAffinity > calculateSegmentAffinities (List <LeafReaderContext > leafReaderContexts , float [] queryVector , int [] costs )
264289 throws IOException {
265290 List <SegmentAffinity > segmentAffinities = new ArrayList <>(leafReaderContexts .size ());
266291
292+ int i = 0 ;
267293 for (LeafReaderContext context : leafReaderContexts ) {
268294 LeafReader leafReader = context .reader ();
269295 FieldInfo fieldInfo = leafReader .getFieldInfos ().fieldInfo (field );
@@ -293,8 +319,9 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
293319 float centroidsScore = similarityFunction .compare (queryVector , globalCentroid );
294320
295321 // clusters per vector (< 1), higher is better (better coverage)
322+ int numVectors = costs [i ];
296323 int numCentroids = reader .getNumCentroids (fieldInfo );
297- double centroidDensity = (double ) numCentroids / leafReader . numDocs () ;
324+ double centroidDensity = (double ) numCentroids / numVectors ;
298325
299326 // with larger clusters, global centroid might not be a good representative,
300327 // so we want to include "some" centroids' scores for higher quality estimate
@@ -315,17 +342,18 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
315342
316343 double affinityScore = centroidsScore * (1 + centroidDensity );
317344
318- segmentAffinities .add (new SegmentAffinity (context , affinityScore ));
345+ segmentAffinities .add (new SegmentAffinity (context , affinityScore , numVectors ));
319346 } else {
320- segmentAffinities .add (new SegmentAffinity (context , 0.5 ));
347+ segmentAffinities .add (new SegmentAffinity (context , Float . NaN , 0 ));
321348 }
322349 }
350+ i ++;
323351 }
324352
325353 return segmentAffinities ;
326354 }
327355
328- private record SegmentAffinity (LeafReaderContext context , double affinityScore ) {}
356+ private record SegmentAffinity (LeafReaderContext context , double affinityScore , int numVectors ) {}
329357
330358 private TopDocs searchLeaf (
331359 LeafReaderContext ctx ,
0 commit comments