Skip to content

Commit 0635453

Browse files
committed
more sensible and generic threshold definition
1 parent 9499be1 commit 0635453

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
import java.util.List;
5353
import java.util.Map;
5454
import java.util.Objects;
55+
import java.util.OptionalDouble;
5556
import java.util.concurrent.Callable;
57+
import java.util.stream.Collectors;
5658

5759
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
5860
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
@@ -139,25 +141,29 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
139141

140142
List<LeafReaderContext> selectedSegments = new ArrayList<>();
141143

142-
double cutoff_affinity = 0.3; // minimum affinity score for a segment to be considered
143-
double higher_affinity = 0.7; // min affinity for increasing nProbe
144-
double lower_affinity = 0.6 ; // max affinity for decreasing nProbe
144+
double[] affinityScores = segmentAffinities.stream().map(SegmentAffinity::affinityScore).mapToDouble(Double::doubleValue).toArray();
145145

146-
int max_adjustment = 20;
146+
// max affinity for decreasing nProbe
147+
double average = Arrays.stream(affinityScores).average().orElseThrow();
148+
double maxAffinity = Arrays.stream(affinityScores).max().orElseThrow();
149+
double lowerAffinity = (maxAffinity + average) * 0.5;
150+
double cutoffAffinity = lowerAffinity * 0.5; // minimum affinity score for a segment to be considered
151+
double affinityTreshold = (maxAffinity + lowerAffinity) * 0.66; // min affinity for increasing nProbe
152+
int maxAdjustments = (int) (nProbe * 1.5);
147153

148154
Map<LeafReaderContext, Integer> segmentNProbeMap = new HashMap<>();
149155
// Process segments based on their affinity scores
150156
for (SegmentAffinity affinity : segmentAffinities) {
151157
double score = affinity.affinityScore();
152158

153159
// Skip segments with very low affinity
154-
if (score < cutoff_affinity) {
160+
if (score < cutoffAffinity) {
155161
continue;
156162
}
157163

158164
// Adjust nProbe based on affinity score
159165
// with larger affinity we increase nprobe (and viceversa)
160-
int adjustedNProbe = adjustNProbeForSegment(score, higher_affinity, lower_affinity, max_adjustment);
166+
int adjustedNProbe = adjustNProbeForSegment(score, affinityTreshold, maxAdjustments);
161167

162168
// Store the adjusted nProbe value for this segment
163169
segmentNProbeMap.put(affinity.context(), adjustedNProbe);
@@ -179,19 +185,19 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
179185
return new KnnScoreDocQuery(topK.scoreDocs, reader);
180186
}
181187

182-
private int adjustNProbeForSegment(double affinityScore, double highThreshold, double lowThreshold, int maxAdjustment) {
188+
private int adjustNProbeForSegment(double affinityScore, double affinityTreshold, int maxAdjustment) {
183189
int baseNProbe = this.nProbe;
184190

185-
// For very high affinity scores, increase nProbe
186-
if (affinityScore >= highThreshold) {
187-
int adjustment = (int) Math.ceil((affinityScore - highThreshold) * maxAdjustment);
191+
// for high affinity scores, increase nProbe
192+
if (affinityScore > affinityTreshold) {
193+
int adjustment = (int) Math.ceil((affinityScore - affinityTreshold) * maxAdjustment);
188194
return Math.min(baseNProbe * adjustment, baseNProbe + maxAdjustment);
189195
}
190196

191-
// For low affinity scores, decrease nProbe
192-
if (affinityScore <= lowThreshold) {
193-
int adjustment = (int) Math.ceil((lowThreshold - affinityScore) * maxAdjustment);
194-
return Math.max(baseNProbe / 3, 1); // Ensure nProbe doesn't go below 1
197+
// for low affinity scores, decrease nProbe
198+
if (affinityScore <= affinityTreshold) {
199+
//int adjustment = (int) Math.ceil((affinityTreshold - affinityScore) * maxAdjustment);
200+
return Math.max(baseNProbe / 3, 1);
195201
}
196202

197203
return baseNProbe;
@@ -233,30 +239,36 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
233239
VectorUtil.l2normalize(queryVector);
234240
}
235241
// similarity between query vector and global centroid, higher is better
236-
float globalCentroidScore = similarityFunction.compare(queryVector, globalCentroid);
242+
float centroidsScore = similarityFunction.compare(queryVector, globalCentroid);
237243
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
238-
globalCentroidScore = VectorUtil.scaleMaxInnerProductScore(globalCentroidScore);
244+
centroidsScore = VectorUtil.scaleMaxInnerProductScore(centroidsScore);
239245
}
240246

241247
// clusters per vector (< 1), higher is better (better coverage)
242248
int numCentroids = reader.getNumCentroids(fieldInfo);
243249
double centroidDensity = (double) numCentroids / leafReader.numDocs();
244250

245-
// include some centroids' scores
246-
if (numCentroids > 32) {
251+
// with larger clusters, global centroid might not be a good representative,
252+
// so we want to include "some" centroids' scores for higher quality estimate
253+
if (numCentroids > 64) {
247254
float[] centroidScores = reader.getCentroidsScores(
248255
fieldInfo,
249256
numCentroids,
250257
reader.getIvfCentroids(fieldInfo),
251258
queryVector,
252-
numCentroids > 64
259+
numCentroids > 128
253260
);
254261
Arrays.sort(centroidScores);
255-
globalCentroidScore = (globalCentroidScore + centroidScores[centroidScores.length - 1]
256-
+ centroidScores[centroidScores.length - 2]) / 3;
262+
float first = centroidScores[centroidScores.length - 1];
263+
float second = centroidScores[centroidScores.length - 2];
264+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
265+
first = VectorUtil.scaleMaxInnerProductScore(first);
266+
second = VectorUtil.scaleMaxInnerProductScore(second);
267+
}
268+
centroidsScore = (centroidsScore + first + second) / 3;
257269
}
258270

259-
double affinityScore = globalCentroidScore * (1 + centroidDensity);
271+
double affinityScore = centroidsScore * (1 + centroidDensity);
260272

261273
segmentAffinities.add(new SegmentAffinity(context, affinityScore));
262274
} else {

0 commit comments

Comments
 (0)