Skip to content

Commit daa8f40

Browse files
committed
work better with unbalanced segments, simplified
1 parent 3d1c6b6 commit daa8f40

File tree

2 files changed: 20 additions and 44 deletions

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,6 @@ public IndexInput getIvfCentroids(FieldInfo fieldInfo) throws IOException {
353353
return fields.get(fieldInfo.number).centroidSlice(ivfCentroids);
354354
}
355355

356-
public int getNumCentroids(FieldInfo fieldInfo) {
357-
return fields.get(fieldInfo.number).numCentroids;
358-
}
359-
360356
public float[] getGlobalCentroid(FieldInfo fieldInfo) {
361357
return fields.get(fieldInfo.number).globalCentroid;
362358
}

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848

4949
import java.io.IOException;
5050
import java.util.ArrayList;
51-
import java.util.Arrays;
5251
import java.util.Collections;
5352
import java.util.List;
5453
import java.util.Objects;
@@ -62,11 +61,11 @@
6261
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
6362

6463
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
65-
public static final float MIN_VISIT_RATIO_FOR_AFFINITY_ADJUSTMENT = 0.004f;
66-
public static final float MAX_AFFINITY_MULTIPLIER_ADJUSTMENT = 1.1f;
67-
public static final float MIN_AFFINITY_MULTIPLIER_ADJUSTMENT = 0.5f;
68-
public static final float MIN_AFFINITY = 0.001f;
69-
public static final float MAX_AFFINITY = 1f;
64+
private static final float MIN_VISIT_RATIO_FOR_AFFINITY_ADJUSTMENT = 0.004f;
65+
private static final float MAX_AFFINITY_MULTIPLIER_ADJUSTMENT = 1.1f;
66+
private static final float MIN_AFFINITY_MULTIPLIER_ADJUSTMENT = 0.75f;
67+
private static final float MIN_AFFINITY = 0.001f;
68+
private static final float MAX_AFFINITY = 1f;
7069

7170
protected final String field;
7271
protected final float providedVisitRatio;
@@ -179,43 +178,24 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
179178
List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector(), costs);
180179
segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore()));
181180

182-
double[] affinityScores = segmentAffinities.stream()
183-
.map(SegmentAffinity::affinityScore)
184-
.mapToDouble(Double::doubleValue)
185-
.filter(x -> Double.isNaN(x) == false && Double.isInfinite(x) == false)
186-
.toArray();
187-
188-
double minAffinity = Arrays.stream(affinityScores).min().orElse(Double.NaN);
189-
double maxAffinity = Arrays.stream(affinityScores).max().orElse(Double.NaN);
190-
191-
double[] normalizedAffinityScores = Arrays.stream(affinityScores)
192-
.map(d -> (d - minAffinity) / (maxAffinity - minAffinity))
193-
.toArray();
194-
195-
// TODO : enable affinity optimization for filtered case
196-
if (filterWeight != null
197-
|| normalizedAffinityScores.length != segmentAffinities.size()
198-
|| Double.isNaN(minAffinity)
199-
|| Double.isNaN(maxAffinity)
181+
if (filterWeight != null // TODO : enable affinity optimization for filtered case
200182
|| leafReaderContexts.size() == 1) {
201183
tasks = new ArrayList<>(leafReaderContexts.size());
202184
for (LeafReaderContext context : leafReaderContexts) {
203185
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio));
204186
}
205187
} else {
206188
tasks = new ArrayList<>(segmentAffinities.size());
207-
int j = 0;
208189
for (SegmentAffinity segmentAffinity : segmentAffinities) {
209-
double normalizedAffinityScore = normalizedAffinityScores[j];
190+
double affinityScore = segmentAffinity.affinityScore;
210191

211192
float adjustedVisitRatio = adjustVisitRatioForSegment(
212-
normalizedAffinityScore,
213-
normalizedAffinityScores[normalizedAffinityScores.length / 10],
193+
affinityScore,
194+
segmentAffinities.get(segmentAffinities.size() / 10).affinityScore,
214195
visitRatio
215196
);
216197

217198
tasks.add(() -> searchLeaf(segmentAffinity.context(), filterWeight, knnCollectorManager, adjustedVisitRatio));
218-
j++;
219199
}
220200
}
221201
} else {
@@ -241,12 +221,14 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
241221
private float adjustVisitRatioForSegment(double affinityScore, double affinityThreshold, float visitRatio) {
242222
// for high affinity scores, increase visited ratio
243223
if (affinityScore > affinityThreshold) {
244-
return Math.min(visitRatio * MAX_AFFINITY_MULTIPLIER_ADJUSTMENT, MAX_AFFINITY);
224+
double adjustment = Math.min(1 + (affinityScore - affinityThreshold), MAX_AFFINITY_MULTIPLIER_ADJUSTMENT);
225+
return Math.min((float) (visitRatio * adjustment), MAX_AFFINITY);
245226
}
246227

247228
// for low affinity scores, decrease visited ratio
248-
if (affinityScore <= affinityThreshold) {
249-
return Math.max(visitRatio * MIN_AFFINITY_MULTIPLIER_ADJUSTMENT, MIN_AFFINITY);
229+
if (affinityScore < affinityThreshold) {
230+
double adjustment = Math.max(1 - (affinityThreshold - affinityScore), MIN_AFFINITY_MULTIPLIER_ADJUSTMENT);
231+
return (float) Math.max(visitRatio * adjustment, MIN_AFFINITY);
250232
}
251233

252234
return visitRatio;
@@ -298,19 +280,17 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
298280
+ fieldInfo.getVectorDimension()
299281
);
300282
}
301-
// similarity between query vector and global centroid, higher is better
283+
302284
float centroidsScore = similarityFunction.compare(queryVector, globalCentroid);
303285

304-
// clusters per vector (< 1), higher is better (better coverage)
305286
int numVectors = costs[i];
306-
int numCentroids = reader.getNumCentroids(fieldInfo);
307-
double centroidDensity = (double) numCentroids / numVectors;
308287

309288
// TODO : we may want to include some actual centroids' scores for higher quality estimate
310-
double affinityScore = centroidsScore * Math.log10(numVectors) * (1 + centroidDensity);
311-
segmentAffinities.add(new SegmentAffinity(context, affinityScore, numVectors));
289+
double affinityScore = centroidsScore * (Math.log10(numVectors));
290+
291+
segmentAffinities.add(new SegmentAffinity(context, affinityScore));
312292
} else {
313-
segmentAffinities.add(new SegmentAffinity(context, Float.NaN, 0));
293+
segmentAffinities.add(new SegmentAffinity(context, Float.NaN));
314294
}
315295
}
316296
i++;
@@ -319,7 +299,7 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
319299
return segmentAffinities;
320300
}
321301

322-
private record SegmentAffinity(LeafReaderContext context, double affinityScore, int numVectors) {}
302+
private record SegmentAffinity(LeafReaderContext context, double affinityScore) {}
323303

324304
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, IVFCollectorManager knnCollectorManager, float visitRatio)
325305
throws IOException {

0 commit comments

Comments
 (0)