Skip to content

Commit d261f02

Browse files
committed
remove explicit budget as it conflicts with visitedRatio, normalized affinity as visited ratio modifier
1 parent d631243 commit d261f02

File tree

4 files changed

+36
-173
lines changed

4 files changed

+36
-173
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -138,63 +138,6 @@ CentroidIterator getCentroidIterator(
138138
return getPostingListPrefetchIterator(centroidIterator, postingListSlice);
139139
}
140140

141-
@Override
142-
public float[] getCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery, boolean parents)
143-
throws IOException {
144-
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
145-
final float globalCentroidDp = fieldEntry.globalCentroidDp();
146-
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
147-
final int[] scratch = new int[targetQuery.length];
148-
float[] targetQueryCopy = ArrayUtil.copyArray(targetQuery);
149-
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
150-
VectorUtil.l2normalize(targetQueryCopy);
151-
}
152-
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
153-
targetQueryCopy,
154-
scratch,
155-
(byte) 7,
156-
fieldEntry.globalCentroid()
157-
);
158-
final byte[] quantizedQuery = new byte[targetQuery.length];
159-
for (int i = 0; i < quantizedQuery.length; i++) {
160-
quantizedQuery[i] = (byte) scratch[i];
161-
}
162-
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
163-
centroids.seek(0L);
164-
// final scores
165-
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
166-
167-
int numParents = centroids.readVInt();
168-
if (parents && numParents > 0) {
169-
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
170-
score(
171-
parentsQueue,
172-
numParents,
173-
0,
174-
scorer,
175-
quantizedQuery,
176-
queryParams,
177-
globalCentroidDp,
178-
fieldInfo.getVectorSimilarityFunction(),
179-
scores
180-
);
181-
} else {
182-
final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
183-
score(
184-
neighborQueue,
185-
numCentroids,
186-
0,
187-
scorer,
188-
quantizedQuery,
189-
queryParams,
190-
globalCentroidDp,
191-
fieldInfo.getVectorSimilarityFunction(),
192-
scores
193-
);
194-
}
195-
return scores;
196-
}
197-
198141
private static CentroidIterator getCentroidIteratorNoParent(
199142
FieldInfo fieldInfo,
200143
IndexInput centroids,

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ abstract CentroidIterator getCentroidIterator(
9393
IndexInput postingListSlice
9494
) throws IOException;
9595

96-
public abstract float[] getCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target, boolean parents)
97-
throws IOException;
98-
9996
private static IndexInput openDataInput(
10097
SegmentReadState state,
10198
int versionMeta,
@@ -272,9 +269,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
272269
// is enough?
273270
expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
274271
actualDocs += scorer.visit(knnCollector);
275-
if (knnCollector.earlyTerminated()) {
276-
break;
277-
}
278272
}
279273
if (acceptDocs != null) {
280274
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
@@ -284,9 +278,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
284278
CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
285279
scorer.resetPostingsScorer(offsetAndLength.offset());
286280
actualDocs += scorer.visit(knnCollector);
287-
if (knnCollector.earlyTerminated()) {
288-
break;
289-
}
290281
}
291282
}
292283
}

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 33 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerP
6262

6363
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
6464

65+
private static final float MAX_VISIT_INCREASE_RATIO = 0.1f;
66+
private static final double MAX_VISIT_DECREASE_RATIO = 0.01f;
67+
6568
protected final String field;
6669
protected final float providedVisitRatio;
6770
protected final int k;
@@ -166,79 +169,48 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
166169
visitRatio = providedVisitRatio;
167170
}
168171

169-
// FIXME: pick a reasonable min budget and make sure visitRatio is used appropriately throughout and not pushing the budget to zero
170-
int totalBudget = Math.max(100, (int) (totalDocsWVectors * visitRatio));
171-
172172
List<Callable<TopDocs>> tasks;
173173
if (leafReaderContexts.isEmpty() == false) {
174174
// calculate the affinity of each segment to the query vector
175-
// (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.)
176175
List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector(), costs);
176+
segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore()));
177177

178-
// TODO: sort segments by affinity score in descending order, and cut the long tail ?
179178
double[] affinityScores = segmentAffinities.stream()
180179
.map(SegmentAffinity::affinityScore)
181180
.mapToDouble(Double::doubleValue)
182-
.toArray();
183-
184-
double[] filteredAffinityScores = Arrays.stream(affinityScores)
185181
.filter(x -> Double.isNaN(x) == false && Double.isInfinite(x) == false)
186182
.toArray();
187183

188-
final double averageAffinity = Arrays.stream(filteredAffinityScores).average().orElse(0.0);
189-
double variance = Arrays.stream(filteredAffinityScores).map(x -> (x - averageAffinity) * (x - averageAffinity)).sum()
190-
/ filteredAffinityScores.length;
191-
double stdDev = Math.sqrt(variance);
192-
193-
double maxAffinity;
194-
double cutoffAffinity;
195-
double affinityThreshold;
196-
197-
if (stdDev > averageAffinity) {
198-
// adjust calculation when distribution is very skewed
199-
double minAffinity = Arrays.stream(affinityScores).min().orElse(Double.NaN);
200-
maxAffinity = Arrays.stream(affinityScores).max().orElse(Double.NaN);
201-
double lowerAffinity = (minAffinity + averageAffinity) * 0.5;
202-
cutoffAffinity = lowerAffinity * 0.1;
203-
affinityThreshold = (minAffinity + lowerAffinity) * 0.66;
204-
} else {
205-
maxAffinity = averageAffinity + 50 * stdDev;
206-
cutoffAffinity = averageAffinity - 50 * stdDev;
207-
affinityThreshold = averageAffinity + stdDev;
208-
}
184+
double minAffinity = Arrays.stream(affinityScores).min().orElse(Double.NaN);
185+
double maxAffinity = Arrays.stream(affinityScores).max().orElse(Double.NaN);
209186

210-
float maxAdjustments = visitRatio * 1.5f;
187+
double[] normalizedAffinityScores = Arrays.stream(affinityScores)
188+
.map(d -> (d - minAffinity) / (maxAffinity - minAffinity))
189+
.toArray();
211190

212-
if (Double.isNaN(maxAffinity) || Double.isNaN(averageAffinity)) {
191+
if (normalizedAffinityScores.length != segmentAffinities.size()) {
213192
tasks = new ArrayList<>(leafReaderContexts.size());
214193
for (LeafReaderContext context : leafReaderContexts) {
215-
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio, Integer.MAX_VALUE));
194+
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio));
216195
}
217196
} else {
218197
tasks = new ArrayList<>(segmentAffinities.size());
219-
double scoreVectorsSum = segmentAffinities.stream().map(segmentAffinity -> {
220-
double affinity = Double.isNaN(segmentAffinity.affinityScore) ? maxAffinity : segmentAffinity.affinityScore;
221-
affinity = Math.clamp(affinity, 0.0, maxAffinity);
222-
return affinity * segmentAffinity.numVectors;
223-
}).mapToDouble(Double::doubleValue).sum();
224-
198+
int j = 0;
225199
for (SegmentAffinity segmentAffinity : segmentAffinities) {
226-
double score = segmentAffinity.affinityScore();
227-
if (score < cutoffAffinity) {
228-
continue;
229-
}
230-
float adjustedVisitRatio = adjustVisitRatioForSegment(score, affinityThreshold, maxAdjustments, visitRatio);
231-
LeafReaderContext context = segmentAffinity.context();
200+
double normalizedAffinityScore = normalizedAffinityScores[j];
232201

233-
// distribute the budget according to : budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
234-
int segmentBudget = (int) (totalBudget * (score * segmentAffinity.numVectors) / scoreVectorsSum);
202+
float adjustedVisitRatio = (float) Math.clamp(
203+
(visitRatio * (MAX_VISIT_INCREASE_RATIO + normalizedAffinityScore)),
204+
visitRatio * MAX_VISIT_DECREASE_RATIO,
205+
visitRatio * (1 + MAX_VISIT_INCREASE_RATIO)
206+
);
235207

236-
// TODO : should we always grant a min budget for each affine-enough segment
237-
if (segmentBudget > 0) {
238-
tasks.add(
239-
() -> searchLeaf(context, filterWeight, knnCollectorManager, adjustedVisitRatio, Math.max(1, segmentBudget))
240-
);
208+
LeafReaderContext context = segmentAffinity.context();
209+
210+
if (adjustedVisitRatio > 0) {
211+
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, adjustedVisitRatio));
241212
}
213+
j++;
242214
}
243215
}
244216
} else {
@@ -257,21 +229,6 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
257229

258230
abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException;
259231

260-
private float adjustVisitRatioForSegment(double affinityScore, double affinityThreshold, float maxAdjustment, float visitRatio) {
261-
// for high affinity scores, increase visited ratio
262-
if (affinityScore > affinityThreshold) {
263-
int adjustment = (int) Math.ceil((affinityScore - affinityThreshold) * maxAdjustment);
264-
return Math.min(visitRatio * adjustment, visitRatio + maxAdjustment);
265-
}
266-
267-
// for low affinity scores, decrease visited ratio
268-
if (affinityScore <= affinityThreshold) {
269-
return Math.max(visitRatio * 0.5f, 0.01f);
270-
}
271-
272-
return visitRatio;
273-
}
274-
275232
abstract float[] getQueryVector() throws IOException;
276233

277234
private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
@@ -287,8 +244,7 @@ private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
287244
return result;
288245
}
289246

290-
private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext> leafReaderContexts, float[] queryVector, int[] costs)
291-
throws IOException {
247+
private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext> leafReaderContexts, float[] queryVector, int[] costs) {
292248
List<SegmentAffinity> segmentAffinities = new ArrayList<>(leafReaderContexts.size());
293249

294250
int i = 0;
@@ -325,25 +281,8 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
325281
int numCentroids = reader.getNumCentroids(fieldInfo);
326282
double centroidDensity = (double) numCentroids / numVectors;
327283

328-
// with larger clusters, global centroid might not be a good representative,
329-
// so we want to include "some" centroids' scores for higher quality estimate
330-
// TODO: tweak the threshold numCentroids here
331-
if (numCentroids > 64) {
332-
float[] centroidScores = reader.getCentroidsScores(
333-
fieldInfo,
334-
numCentroids,
335-
reader.getIvfCentroids(fieldInfo),
336-
queryVector,
337-
numCentroids > 128
338-
);
339-
Arrays.sort(centroidScores);
340-
float first = centroidScores[centroidScores.length - 1];
341-
float second = centroidScores[centroidScores.length - 2];
342-
centroidsScore = (centroidsScore + first + second) / 3;
343-
}
344-
284+
// TODO : we may want to include some actual centroids' scores for higher quality estimate
345285
double affinityScore = centroidsScore * (1 + centroidDensity);
346-
347286
segmentAffinities.add(new SegmentAffinity(context, affinityScore, numVectors));
348287
} else {
349288
segmentAffinities.add(new SegmentAffinity(context, Float.NaN, 0));
@@ -357,14 +296,9 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
357296

358297
private record SegmentAffinity(LeafReaderContext context, double affinityScore, int numVectors) {}
359298

360-
private TopDocs searchLeaf(
361-
LeafReaderContext ctx,
362-
Weight filterWeight,
363-
KnnCollectorManager knnCollectorManager,
364-
float visitRatio,
365-
int visitingBudget
366-
) throws IOException {
367-
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager, visitRatio, visitingBudget);
299+
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager, float visitRatio)
300+
throws IOException {
301+
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager, visitRatio);
368302
IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3);
369303
int deduplicateCount = 0;
370304
for (ScoreDoc scoreDoc : results.scoreDocs) {
@@ -384,19 +318,13 @@ private TopDocs searchLeaf(
384318
return new TopDocs(results.totalHits, deduplicatedScoreDocs);
385319
}
386320

387-
TopDocs getLeafResults(
388-
LeafReaderContext ctx,
389-
Weight filterWeight,
390-
KnnCollectorManager knnCollectorManager,
391-
float visitRatio,
392-
int visitingBudget
393-
) throws IOException {
321+
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager, float visitRatio)
322+
throws IOException {
394323
final LeafReader reader = ctx.reader();
395324
final Bits liveDocs = reader.getLiveDocs();
396325

397-
KnnSearchStrategy searchStrategy = new IVFKnnSearchStrategy(visitRatio);
398326
if (filterWeight == null) {
399-
return approximateSearch(ctx, liveDocs, visitingBudget, knnCollectorManager, searchStrategy);
327+
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager, visitRatio);
400328
}
401329

402330
Scorer scorer = filterWeight.scorer(ctx);
@@ -406,15 +334,15 @@ TopDocs getLeafResults(
406334

407335
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
408336
final int cost = acceptDocs.cardinality();
409-
return approximateSearch(ctx, acceptDocs, Math.min(visitingBudget, cost + 1), knnCollectorManager, searchStrategy);
337+
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager, visitRatio);
410338
}
411339

412340
abstract TopDocs approximateSearch(
413341
LeafReaderContext context,
414342
Bits acceptDocs,
415343
int visitedLimit,
416344
KnnCollectorManager knnCollectorManager,
417-
KnnSearchStrategy searchStrategy
345+
float visitRatio
418346
) throws IOException;
419347

420348
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {

server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ protected TopDocs approximateSearch(
9292
Bits acceptDocs,
9393
int visitedLimit,
9494
KnnCollectorManager knnCollectorManager,
95-
KnnSearchStrategy searchStrategy
95+
float visitRatio
9696
) throws IOException {
9797
LeafReader reader = context.reader();
9898
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
@@ -103,7 +103,8 @@ protected TopDocs approximateSearch(
103103
if (floatVectorValues.size() == 0) {
104104
return NO_RESULTS;
105105
}
106-
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
106+
KnnSearchStrategy strategy = new IVFKnnSearchStrategy(visitRatio);
107+
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, strategy, context);
107108
if (knnCollector == null) {
108109
return NO_RESULTS;
109110
}

0 commit comments

Comments
 (0)