Skip to content

Commit b097c20

Browse files
committed
Include (parent) centroid scores in the segment affinity calculation for larger segments
1 parent 8cbad8f commit b097c20

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
8787
}
8888

8989
@Override
90-
public float[] getParentCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
90+
public float[] getCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery, boolean parents)
9191
throws IOException {
9292
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
9393
final float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -109,11 +109,11 @@ public float[] getParentCentroidsScores(FieldInfo fieldInfo, int numCentroids, I
109109
}
110110
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
111111
centroids.seek(0L);
112-
// score the parents
112+
// final scores
113113
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
114114

115115
int numParents = centroids.readVInt();
116-
if (numParents > 0) {
116+
if (parents && numParents > 0) {
117117
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
118118
final int maxChildrenSize = centroids.readVInt();
119119
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
@@ -130,6 +130,19 @@ public float[] getParentCentroidsScores(FieldInfo fieldInfo, int numCentroids, I
130130
fieldInfo.getVectorSimilarityFunction(),
131131
scores
132132
);
133+
} else {
134+
final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
135+
score(
136+
neighborQueue,
137+
numCentroids,
138+
0,
139+
scorer,
140+
quantizedQuery,
141+
queryParams,
142+
globalCentroidDp,
143+
fieldInfo.getVectorSimilarityFunction(),
144+
scores
145+
);
133146
}
134147
return scores;
135148
}

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
9191
abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
9292
throws IOException;
9393

94-
public abstract float[] getParentCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
94+
public abstract float[] getCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target, boolean parents)
9595
throws IOException;
9696

9797
private static IndexInput openDataInput(

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,17 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
132132
List<LeafReaderContext> leafReaderContexts = reader.leaves();
133133

134134
// calculate the affinity of each segment to the query vector
135-
// (need information from each segment: no. of clusters, global centroid, density, whatever, ...)
135+
// (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.)
136136
List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector());
137137

138138
// TODO: sort segments by affinity score in descending order, and cut the long tail ?
139-
140-
// with larger affinity we increase nprobe (and viceversa)
141-
// also sort segments by affinity and eventually filter out the long tail
139+
142140
List<LeafReaderContext> selectedSegments = new ArrayList<>();
143141

144-
// TODO : are these magic numbers ?
145-
double cutoff_affinity = 0.01; // minimum affinity score for a segment to be considered
146-
double higher_affinity = 0.6; // min affinity for increasing nProbe
147-
double lower_affinity = 0.6; // max affinity for decreasing nProbe
142+
double cutoff_affinity = 0.3; // minimum affinity score for a segment to be considered
143+
double higher_affinity = 0.7; // min affinity for increasing nProbe
144+
double lower_affinity = 0.6 ; // max affinity for decreasing nProbe
145+
148146
int max_adjustment = 20;
149147

150148
Map<LeafReaderContext, Integer> segmentNProbeMap = new HashMap<>();
@@ -244,15 +242,18 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
244242
int numCentroids = reader.getNumCentroids(fieldInfo);
245243
double centroidDensity = (double) numCentroids / leafReader.numDocs();
246244

247-
if (numCentroids > 64) {
248-
float[] parentCentroidsScores = reader.getParentCentroidsScores(
245+
// include some centroids' scores
246+
if (numCentroids > 32) {
247+
float[] centroidScores = reader.getCentroidsScores(
249248
fieldInfo,
250249
numCentroids,
251250
reader.getIvfCentroids(fieldInfo),
252-
queryVector
251+
queryVector,
252+
numCentroids > 64
253253
);
254-
Arrays.sort(parentCentroidsScores);
255-
globalCentroidScore = (parentCentroidsScores[0] + parentCentroidsScores[1] + globalCentroidScore) / 3;
254+
Arrays.sort(centroidScores);
255+
globalCentroidScore = (globalCentroidScore + centroidScores[centroidScores.length - 1]
256+
+ centroidScores[centroidScores.length - 2]) / 3;
256257
}
257258

258259
double affinityScore = globalCentroidScore * (1 + centroidDensity);

0 commit comments

Comments
 (0)