5252import java .util .List ;
5353import java .util .Map ;
5454import java .util .Objects ;
55+ import java .util .OptionalDouble ;
5556import java .util .concurrent .Callable ;
57+ import java .util .stream .Collectors ;
5658
5759import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
5860import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
@@ -139,25 +141,29 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
139141
140142 List <LeafReaderContext > selectedSegments = new ArrayList <>();
141143
142- double cutoff_affinity = 0.3 ; // minimum affinity score for a segment to be considered
143- double higher_affinity = 0.7 ; // min affinity for increasing nProbe
144- double lower_affinity = 0.6 ; // max affinity for decreasing nProbe
144+ double [] affinityScores = segmentAffinities .stream ().map (SegmentAffinity ::affinityScore ).mapToDouble (Double ::doubleValue ).toArray ();
145145
146- int max_adjustment = 20 ;
146+ // max affinity for decreasing nProbe
147+ double average = Arrays .stream (affinityScores ).average ().orElseThrow ();
148+ double maxAffinity = Arrays .stream (affinityScores ).max ().orElseThrow ();
149+ double lowerAffinity = (maxAffinity + average ) * 0.5 ;
150+ double cutoffAffinity = lowerAffinity * 0.5 ; // minimum affinity score for a segment to be considered
151+ double affinityTreshold = (maxAffinity + lowerAffinity ) * 0.66 ; // min affinity for increasing nProbe
152+ int maxAdjustments = (int ) (nProbe * 1.5 );
147153
148154 Map <LeafReaderContext , Integer > segmentNProbeMap = new HashMap <>();
149155 // Process segments based on their affinity scores
150156 for (SegmentAffinity affinity : segmentAffinities ) {
151157 double score = affinity .affinityScore ();
152158
153159 // Skip segments with very low affinity
154- if (score < cutoff_affinity ) {
160+ if (score < cutoffAffinity ) {
155161 continue ;
156162 }
157163
158164 // Adjust nProbe based on affinity score
159165 // with larger affinity we increase nProbe (and vice versa)
160- int adjustedNProbe = adjustNProbeForSegment (score , higher_affinity , lower_affinity , max_adjustment );
166+ int adjustedNProbe = adjustNProbeForSegment (score , affinityTreshold , maxAdjustments );
161167
162168 // Store the adjusted nProbe value for this segment
163169 segmentNProbeMap .put (affinity .context (), adjustedNProbe );
@@ -179,19 +185,19 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
179185 return new KnnScoreDocQuery (topK .scoreDocs , reader );
180186 }
181187
182- private int adjustNProbeForSegment (double affinityScore , double highThreshold , double lowThreshold , int maxAdjustment ) {
188+ private int adjustNProbeForSegment (double affinityScore , double affinityTreshold , int maxAdjustment ) {
183189 int baseNProbe = this .nProbe ;
184190
185- // For very high affinity scores, increase nProbe
186- if (affinityScore >= highThreshold ) {
187- int adjustment = (int ) Math .ceil ((affinityScore - highThreshold ) * maxAdjustment );
191+ // for high affinity scores, increase nProbe
192+ if (affinityScore > affinityTreshold ) {
193+ int adjustment = (int ) Math .ceil ((affinityScore - affinityTreshold ) * maxAdjustment );
188194 return Math .min (baseNProbe * adjustment , baseNProbe + maxAdjustment );
189195 }
190196
191- // For low affinity scores, decrease nProbe
192- if (affinityScore <= lowThreshold ) {
193- int adjustment = (int ) Math .ceil ((lowThreshold - affinityScore ) * maxAdjustment );
194- return Math .max (baseNProbe / 3 , 1 ); // Ensure nProbe doesn't go below 1
197+ // for low affinity scores, decrease nProbe
198+ if (affinityScore <= affinityTreshold ) {
199+ // int adjustment = (int) Math.ceil((affinityTreshold - affinityScore) * maxAdjustment);
200+ return Math .max (baseNProbe / 3 , 1 );
195201 }
196202
197203 return baseNProbe ;
@@ -233,30 +239,36 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
233239 VectorUtil .l2normalize (queryVector );
234240 }
235241 // similarity between query vector and global centroid, higher is better
236- float globalCentroidScore = similarityFunction .compare (queryVector , globalCentroid );
242+ float centroidsScore = similarityFunction .compare (queryVector , globalCentroid );
237243 if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
238- globalCentroidScore = VectorUtil .scaleMaxInnerProductScore (globalCentroidScore );
244+ centroidsScore = VectorUtil .scaleMaxInnerProductScore (centroidsScore );
239245 }
240246
241247 // clusters per vector (< 1), higher is better (better coverage)
242248 int numCentroids = reader .getNumCentroids (fieldInfo );
243249 double centroidDensity = (double ) numCentroids / leafReader .numDocs ();
244250
245- // include some centroids' scores
246- if (numCentroids > 32 ) {
251+ // with larger clusters, global centroid might not be a good representative,
252+ // so we want to include "some" centroids' scores for higher quality estimate
253+ if (numCentroids > 64 ) {
247254 float [] centroidScores = reader .getCentroidsScores (
248255 fieldInfo ,
249256 numCentroids ,
250257 reader .getIvfCentroids (fieldInfo ),
251258 queryVector ,
252- numCentroids > 64
259+ numCentroids > 128
253260 );
254261 Arrays .sort (centroidScores );
255- globalCentroidScore = (globalCentroidScore + centroidScores [centroidScores .length - 1 ]
256- + centroidScores [centroidScores .length - 2 ]) / 3 ;
262+ float first = centroidScores [centroidScores .length - 1 ];
263+ float second = centroidScores [centroidScores .length - 2 ];
264+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
265+ first = VectorUtil .scaleMaxInnerProductScore (first );
266+ second = VectorUtil .scaleMaxInnerProductScore (second );
267+ }
268+ centroidsScore = (centroidsScore + first + second ) / 3 ;
257269 }
258270
259- double affinityScore = globalCentroidScore * (1 + centroidDensity );
271+ double affinityScore = centroidsScore * (1 + centroidDensity );
260272
261273 segmentAffinities .add (new SegmentAffinity (context , affinityScore ));
262274 } else {
0 commit comments