Skip to content

Commit 22ce362

Browse files
committed
don't add task with 0 budget, adjust thresholds for skewed affinities, minor improvements
1 parent f3afcd6 commit 22ce362

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,21 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
138138

139139
int totalDocsWVectors = 0;
140140
assert this instanceof IVFKnnFloatVectorQuery;
141+
int[] costs = new int[leafReaderContexts.size()];
142+
int i = 0;
141143
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
142144
LeafReader leafReader = leafReaderContext.reader();
143145
FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field);
144146
VectorScorer scorer = createVectorScorer(leafReaderContext, fieldInfo);
147+
int cost;
145148
if (scorer != null) {
146-
totalDocsWVectors += (int) scorer.iterator().cost();
149+
cost = (int) scorer.iterator().cost();
150+
totalDocsWVectors += cost;
151+
} else {
152+
cost = 0;
147153
}
154+
costs[i] = cost;
155+
i++;
148156
}
149157

150158
final float visitRatio;
@@ -165,7 +173,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
165173
if (leafReaderContexts.isEmpty() == false) {
166174
// calculate the affinity of each segment to the query vector
167175
// (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.)
168-
List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector());
176+
List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector(), costs);
169177

170178
// TODO: sort segments by affinity score in descending order, and cut the long tail ?
171179
double[] affinityScores = segmentAffinities.stream()
@@ -182,9 +190,22 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
182190
/ filteredAffinityScores.length;
183191
double stdDev = Math.sqrt(variance);
184192

185-
double maxAffinity = averageAffinity + 50 * stdDev;
186-
double cutoffAffinity = averageAffinity - 50 * stdDev;
187-
double affinityThreshold = averageAffinity + stdDev;
193+
double maxAffinity;
194+
double cutoffAffinity;
195+
double affinityThreshold;
196+
197+
if (stdDev > averageAffinity) {
198+
// adjust calculation when distribution is very skewed (e.g., if (stdDev > averageAffinity) )
199+
double minAffinity = Arrays.stream(affinityScores).min().orElse(Double.NaN);
200+
maxAffinity = Arrays.stream(affinityScores).max().orElse(Double.NaN);
201+
double lowerAffinity = (minAffinity + averageAffinity) * 0.5;
202+
cutoffAffinity = lowerAffinity * 0.1;
203+
affinityThreshold = (minAffinity + lowerAffinity) * 0.66;
204+
} else {
205+
maxAffinity = averageAffinity + 50 * stdDev;
206+
cutoffAffinity = averageAffinity - 50 * stdDev;
207+
affinityThreshold = averageAffinity + stdDev;
208+
}
188209

189210
float maxAdjustments = visitRatio * 1.5f;
190211

@@ -198,7 +219,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
198219
double scoreVectorsSum = segmentAffinities.stream().map(segmentAffinity -> {
199220
double affinity = Double.isNaN(segmentAffinity.affinityScore) ? maxAffinity : segmentAffinity.affinityScore;
200221
affinity = Math.clamp(affinity, 0.0, maxAffinity);
201-
return affinity * segmentAffinity.context.reader().numDocs();
222+
return affinity * segmentAffinity.numVectors;
202223
}).mapToDouble(Double::doubleValue).sum();
203224

204225
for (SegmentAffinity segmentAffinity : segmentAffinities) {
@@ -210,8 +231,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
210231
LeafReaderContext context = segmentAffinity.context();
211232

212233
// distribute the budget according to : budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
213-
int segmentBudget = (int) (totalBudget * (score * context.reader().numDocs()) / scoreVectorsSum);
214-
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, adjustedVisitRatio, Math.max(1, segmentBudget)));
234+
int segmentBudget = (int) (totalBudget * (score * segmentAffinity.numVectors) / scoreVectorsSum);
235+
if (segmentBudget > 0) {
236+
tasks.add(
237+
() -> searchLeaf(context, filterWeight, knnCollectorManager, adjustedVisitRatio, Math.max(1, segmentBudget))
238+
);
239+
}
215240
}
216241
}
217242
} else {
@@ -260,10 +285,11 @@ private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
260285
return result;
261286
}
262287

263-
private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext> leafReaderContexts, float[] queryVector)
288+
private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext> leafReaderContexts, float[] queryVector, int[] costs)
264289
throws IOException {
265290
List<SegmentAffinity> segmentAffinities = new ArrayList<>(leafReaderContexts.size());
266291

292+
int i = 0;
267293
for (LeafReaderContext context : leafReaderContexts) {
268294
LeafReader leafReader = context.reader();
269295
FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field);
@@ -293,8 +319,9 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
293319
float centroidsScore = similarityFunction.compare(queryVector, globalCentroid);
294320

295321
// clusters per vector (< 1), higher is better (better coverage)
322+
int numVectors = costs[i];
296323
int numCentroids = reader.getNumCentroids(fieldInfo);
297-
double centroidDensity = (double) numCentroids / leafReader.numDocs();
324+
double centroidDensity = (double) numCentroids / numVectors;
298325

299326
// with larger clusters, global centroid might not be a good representative,
300327
// so we want to include "some" centroids' scores for higher quality estimate
@@ -315,17 +342,18 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
315342

316343
double affinityScore = centroidsScore * (1 + centroidDensity);
317344

318-
segmentAffinities.add(new SegmentAffinity(context, affinityScore));
345+
segmentAffinities.add(new SegmentAffinity(context, affinityScore, numVectors));
319346
} else {
320-
segmentAffinities.add(new SegmentAffinity(context, 0.5));
347+
segmentAffinities.add(new SegmentAffinity(context, Float.NaN, 0));
321348
}
322349
}
350+
i++;
323351
}
324352

325353
return segmentAffinities;
326354
}
327355

328-
private record SegmentAffinity(LeafReaderContext context, double affinityScore) {}
356+
private record SegmentAffinity(LeafReaderContext context, double affinityScore, int numVectors) {}
329357

330358
private TopDocs searchLeaf(
331359
LeafReaderContext ctx,

0 commit comments

Comments
 (0)