Skip to content

Commit df0210d

Browse files
committed
Include top 2 parent scores into affinity, with larger segments
1 parent 3f4e5fd commit df0210d

File tree

3 files changed

+78
-110
lines changed

3 files changed

+78
-110
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 42 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -48,106 +48,46 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4848
}
4949

5050
@Override
51-
public ScoredCentroidIterator getScoredCentroidIterator(
52-
FieldInfo fieldInfo,
53-
int numCentroids,
54-
IndexInput centroids,
55-
float[] targetQuery
56-
) throws IOException {
51+
CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
52+
throws IOException {
5753
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
5854
final float globalCentroidDp = fieldEntry.globalCentroidDp();
5955
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
6056
final int[] scratch = new int[targetQuery.length];
57+
float[] targetQueryCopy = ArrayUtil.copyArray(targetQuery);
58+
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
59+
VectorUtil.l2normalize(targetQueryCopy);
60+
}
6161
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
62-
ArrayUtil.copyArray(targetQuery),
62+
targetQueryCopy,
6363
scratch,
64-
(byte) 4,
64+
(byte) 7,
6565
fieldEntry.globalCentroid()
6666
);
6767
final byte[] quantized = new byte[targetQuery.length];
6868
for (int i = 0; i < quantized.length; i++) {
6969
quantized[i] = (byte) scratch[i];
7070
}
71-
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
72-
NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
71+
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
7372
centroids.seek(0L);
74-
final float[] centroidCorrectiveValues = new float[3];
75-
for (int i = 0; i < numCentroids; i++) {
76-
final float qcDist = scorer.int4DotProduct(quantized);
77-
centroids.readFloats(centroidCorrectiveValues, 0, 3);
78-
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
79-
// TODO : fix this
80-
float score = Float.NaN;/*int4QuantizedScore(
81-
qcDist,
82-
queryParams,
83-
fieldInfo.getVectorDimension(),
84-
centroidCorrectiveValues,
85-
quantizedCentroidComponentSum,
86-
globalCentroidDp,
87-
fieldInfo.getVectorSimilarityFunction()
88-
);*/
89-
queue.add(i, score);
73+
int numParents = centroids.readVInt();
74+
if (numParents > 0) {
75+
return getCentroidIteratorWithParents(
76+
fieldInfo,
77+
centroids,
78+
numParents,
79+
numCentroids,
80+
scorer,
81+
quantized,
82+
queryParams,
83+
globalCentroidDp
84+
);
9085
}
91-
final long offset = centroids.getFilePointer();
92-
CentroidIterator centroidIterator = new CentroidIterator() {
93-
@Override
94-
public boolean hasNext() {
95-
return queue.size() > 0;
96-
}
97-
98-
@Override
99-
public long nextPostingListOffset() throws IOException {
100-
int centroidOrdinal = queue.pop();
101-
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
102-
return centroids.readLong();
103-
}
104-
};
105-
ScoredCentroidIterator scoredCentroidIterator = new ScoredCentroidIterator() {
106-
private float currentTopScore = Float.NEGATIVE_INFINITY;
107-
private long currentTopCentroidOffset = -1;
108-
109-
@Override
110-
public boolean hasNext() {
111-
return centroidIterator.hasNext();
112-
}
113-
114-
@Override
115-
public void scorePostingList(long offset) throws IOException {
116-
117-
}
118-
119-
@Override
120-
public long next() {
121-
return currentTopCentroidOffset;
122-
}
123-
124-
@Override
125-
public long nextPostingListOffset() throws IOException {
126-
return centroidIterator.nextPostingListOffset();
127-
}
128-
129-
public float getCurrentTopScore() {
130-
return currentTopScore;
131-
}
132-
133-
private void updateTopScore(float score, long offset) {
134-
if (score > currentTopScore) {
135-
currentTopScore = score;
136-
currentTopCentroidOffset = offset;
137-
}
138-
}
139-
140-
public long getCurrentTopCentroidOffset() {
141-
return currentTopCentroidOffset;
142-
}
143-
144-
};
145-
return scoredCentroidIterator;
146-
86+
return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
14787
}
14888

14989
@Override
150-
CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
90+
public float[] getParentCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
15191
throws IOException {
15292
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
15393
final float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -163,26 +103,35 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
163103
(byte) 7,
164104
fieldEntry.globalCentroid()
165105
);
166-
final byte[] quantized = new byte[targetQuery.length];
167-
for (int i = 0; i < quantized.length; i++) {
168-
quantized[i] = (byte) scratch[i];
106+
final byte[] quantizedQuery = new byte[targetQuery.length];
107+
for (int i = 0; i < quantizedQuery.length; i++) {
108+
quantizedQuery[i] = (byte) scratch[i];
169109
}
170110
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
171111
centroids.seek(0L);
112+
// score the parents
113+
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
114+
172115
int numParents = centroids.readVInt();
173116
if (numParents > 0) {
174-
return getCentroidIteratorWithParents(
175-
fieldInfo,
176-
centroids,
117+
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
118+
final int maxChildrenSize = centroids.readVInt();
119+
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
120+
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
121+
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
122+
score(
123+
parentsQueue,
177124
numParents,
178-
numCentroids,
125+
0,
179126
scorer,
180-
quantized,
127+
quantizedQuery,
181128
queryParams,
182-
globalCentroidDp
129+
globalCentroidDp,
130+
fieldInfo.getVectorSimilarityFunction(),
131+
scores
183132
);
184133
}
185-
return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
134+
return scores;
186135
}
187136

188137
private static CentroidIterator getCentroidIteratorNoParent(

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
9191
abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
9292
throws IOException;
9393

94+
public abstract float[] getParentCentroidsScores(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
95+
throws IOException;
96+
9497
private static IndexInput openDataInput(
9598
SegmentReadState state,
9699
int versionMeta,
@@ -292,13 +295,6 @@ public void close() throws IOException {
292295
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
293296
}
294297

295-
public abstract ScoredCentroidIterator getScoredCentroidIterator(
296-
FieldInfo fieldInfo,
297-
int numCentroids,
298-
IndexInput centroids,
299-
float[] queryVector
300-
) throws IOException;
301-
302298
protected record FieldEntry(
303299
VectorSimilarityFunction similarityFunction,
304300
VectorEncoding vectorEncoding,
@@ -343,8 +339,8 @@ interface PostingVisitor {
343339
int visit(KnnCollector collector) throws IOException;
344340
}
345341

346-
public IndexInput getIvfCentroids() {
347-
return ivfCentroids;
342+
public IndexInput getIvfCentroids(FieldInfo fieldInfo) throws IOException {
343+
return fields.get(fieldInfo.number).centroidSlice(ivfCentroids);
348344
}
349345

350346
public int getNumCentroids(FieldInfo fieldInfo) {

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import org.apache.lucene.search.Weight;
3838
import org.apache.lucene.search.knn.KnnCollectorManager;
3939
import org.apache.lucene.search.knn.KnnSearchStrategy;
40-
import org.apache.lucene.store.IndexInput;
4140
import org.apache.lucene.util.BitSet;
4241
import org.apache.lucene.util.BitSetIterator;
4342
import org.apache.lucene.util.Bits;
@@ -48,12 +47,14 @@
4847

4948
import java.io.IOException;
5049
import java.util.ArrayList;
50+
import java.util.Arrays;
5151
import java.util.HashMap;
5252
import java.util.List;
5353
import java.util.Map;
5454
import java.util.Objects;
5555
import java.util.concurrent.Callable;
5656

57+
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
5758
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
5859

5960
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
@@ -134,6 +135,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
134135
// (need information from each segment: no. of clusters, global centroid, density, whatever, ...)
135136
List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector());
136137

138+
// TODO: sort segments by affinity score in descending order, and cut the long tail ?
139+
// segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore()));
140+
// ...subList(0, (int) (segmentAffinities.size() * 0.99));
141+
137142
// with larger affinity we increase nprobe (and viceversa)
138143
// also sort segments by affinity and eventually filter out the long tail
139144
List<LeafReaderContext> selectedSegments = new ArrayList<>();
@@ -155,6 +160,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
155160
}
156161

157162
// Adjust nProbe based on affinity score
163+
// with larger affinity we increase nprobe (and viceversa)
158164
int adjustedNProbe = adjustNProbeForSegment(score, higher_affinity, lower_affinity, max_adjustment);
159165

160166
// Store the adjusted nProbe value for this segment
@@ -197,7 +203,8 @@ private int adjustNProbeForSegment(double affinityScore, double highThreshold, d
197203

198204
abstract float[] getQueryVector() throws IOException;
199205

200-
private List<IVFVectorsReader.ScoredCentroidIterator> collectIterators(List<LeafReaderContext> leafReaderContexts) throws IOException {
206+
/*
207+
private List<IVFVectorsReader.CentroidIterator> collectIterators(List<LeafReaderContext> leafReaderContexts) throws IOException {
201208
List<IVFVectorsReader.ScoredCentroidIterator> iterators = new ArrayList<>(leafReaderContexts.size());
202209
for (LeafReaderContext context : leafReaderContexts) {
203210
LeafReader leafReader = context.reader();
@@ -210,11 +217,12 @@ private List<IVFVectorsReader.ScoredCentroidIterator> collectIterators(List<Leaf
210217
FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field);
211218
int numCentroids = reader.getNumCentroids(fieldInfo);
212219
IndexInput centroids = reader.getIvfCentroids();
213-
iterators.add(reader.getScoredCentroidIterator(fieldInfo, numCentroids, centroids, getQueryVector()));
220+
iterators.add(reader.getCentroidIterator(fieldInfo, numCentroids, centroids, getQueryVector()));
214221
}
215222
}
216223
return iterators;
217224
}
225+
*/
218226

219227
private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
220228
IVFVectorsReader result = null;
@@ -229,29 +237,47 @@ private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
229237
return result;
230238
}
231239

232-
private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext> leafReaderContexts, float[] queryVector) {
240+
private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext> leafReaderContexts, float[] queryVector)
241+
throws IOException {
233242
List<SegmentAffinity> segmentAffinities = new ArrayList<>(leafReaderContexts.size());
234243

235244
for (LeafReaderContext context : leafReaderContexts) {
236245
LeafReader leafReader = context.reader();
237246
FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field);
247+
if (fieldInfo == null) {
248+
continue;
249+
}
238250
VectorSimilarityFunction similarityFunction = fieldInfo.getVectorSimilarityFunction();
239251
if (leafReader instanceof SegmentReader segmentReader) {
240252
KnnVectorsReader vectorReader = segmentReader.getVectorReader();
241253
IVFVectorsReader reader = unwrapReader(vectorReader);
242254
if (reader != null) {
243255
float[] globalCentroid = reader.getGlobalCentroid(fieldInfo);
244-
int numCentroids = reader.getNumCentroids(fieldInfo);
245256

257+
if (similarityFunction == COSINE) {
258+
VectorUtil.l2normalize(queryVector);
259+
}
246260
// similarity between query vector and global centroid, higher is better
247261
float globalCentroidScore = similarityFunction.compare(queryVector, globalCentroid);
248262
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
249263
globalCentroidScore = VectorUtil.scaleMaxInnerProductScore(globalCentroidScore);
250264
}
251265

252266
// clusters per vector (< 1), higher is better (better coverage)
267+
int numCentroids = reader.getNumCentroids(fieldInfo);
253268
double centroidDensity = (double) numCentroids / leafReader.numDocs();
254269

270+
if (numCentroids > 64) {
271+
float[] parentCentroidsScores = reader.getParentCentroidsScores(
272+
fieldInfo,
273+
numCentroids,
274+
reader.getIvfCentroids(fieldInfo),
275+
queryVector
276+
);
277+
Arrays.sort(parentCentroidsScores);
278+
globalCentroidScore = (parentCentroidsScores[0] + parentCentroidsScores[1] + globalCentroidScore) / 3;
279+
}
280+
255281
double affinityScore = globalCentroidScore * (1 + centroidDensity);
256282

257283
segmentAffinities.add(new SegmentAffinity(context, affinityScore));
@@ -261,9 +287,6 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
261287
}
262288
}
263289

264-
// TODO: sort segments by affinity score in descending order, and cut the long tail ?
265-
//segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore()));
266-
//...subList(0, (int) (segmentAffinities.size() * 0.99));
267290
return segmentAffinities;
268291
}
269292

0 commit comments

Comments
 (0)