3737import org .apache .lucene .search .Weight ;
3838import org .apache .lucene .search .knn .KnnCollectorManager ;
3939import org .apache .lucene .search .knn .KnnSearchStrategy ;
40- import org .apache .lucene .store .IndexInput ;
4140import org .apache .lucene .util .BitSet ;
4241import org .apache .lucene .util .BitSetIterator ;
4342import org .apache .lucene .util .Bits ;
4847
4948import java .io .IOException ;
5049import java .util .ArrayList ;
50+ import java .util .Arrays ;
5151import java .util .HashMap ;
5252import java .util .List ;
5353import java .util .Map ;
5454import java .util .Objects ;
5555import java .util .concurrent .Callable ;
5656
57+ import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
5758import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
5859
5960abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
@@ -134,6 +135,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
134135 // (need information from each segment: no. of clusters, global centroid, density, whatever, ...)
135136 List <SegmentAffinity > segmentAffinities = calculateSegmentAffinities (leafReaderContexts , getQueryVector ());
136137
138+ // TODO: sort segments by affinity score in descending order, and cut the long tail ?
139+ // segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore()));
140+ // ...subList(0, (int) (segmentAffinities.size() * 0.99));
141+
137142 // with larger affinity we increase nprobe (and viceversa)
138143 // also sort segments by affinity and eventually filter out the long tail
139144 List <LeafReaderContext > selectedSegments = new ArrayList <>();
@@ -155,6 +160,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
155160 }
156161
157162 // Adjust nProbe based on affinity score
163+ // with larger affinity we increase nprobe (and viceversa)
158164 int adjustedNProbe = adjustNProbeForSegment (score , higher_affinity , lower_affinity , max_adjustment );
159165
160166 // Store the adjusted nProbe value for this segment
@@ -197,7 +203,8 @@ private int adjustNProbeForSegment(double affinityScore, double highThreshold, d
197203
198204 abstract float [] getQueryVector () throws IOException ;
199205
200- private List <IVFVectorsReader .ScoredCentroidIterator > collectIterators (List <LeafReaderContext > leafReaderContexts ) throws IOException {
206+ /*
207+ private List<IVFVectorsReader.CentroidIterator> collectIterators(List<LeafReaderContext> leafReaderContexts) throws IOException {
201208 List<IVFVectorsReader.ScoredCentroidIterator> iterators = new ArrayList<>(leafReaderContexts.size());
202209 for (LeafReaderContext context : leafReaderContexts) {
203210 LeafReader leafReader = context.reader();
@@ -210,11 +217,12 @@ private List<IVFVectorsReader.ScoredCentroidIterator> collectIterators(List<Leaf
210217 FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field);
211218 int numCentroids = reader.getNumCentroids(fieldInfo);
212219 IndexInput centroids = reader.getIvfCentroids();
213- iterators .add (reader .getScoredCentroidIterator (fieldInfo , numCentroids , centroids , getQueryVector ()));
220+ iterators.add(reader.getCentroidIterator (fieldInfo, numCentroids, centroids, getQueryVector()));
214221 }
215222 }
216223 return iterators;
217224 }
225+ */
218226
219227 private IVFVectorsReader unwrapReader (KnnVectorsReader knnVectorsReader ) {
220228 IVFVectorsReader result = null ;
@@ -229,29 +237,47 @@ private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) {
229237 return result ;
230238 }
231239
232- private List <SegmentAffinity > calculateSegmentAffinities (List <LeafReaderContext > leafReaderContexts , float [] queryVector ) {
240+ private List <SegmentAffinity > calculateSegmentAffinities (List <LeafReaderContext > leafReaderContexts , float [] queryVector )
241+ throws IOException {
233242 List <SegmentAffinity > segmentAffinities = new ArrayList <>(leafReaderContexts .size ());
234243
235244 for (LeafReaderContext context : leafReaderContexts ) {
236245 LeafReader leafReader = context .reader ();
237246 FieldInfo fieldInfo = leafReader .getFieldInfos ().fieldInfo (field );
247+ if (fieldInfo == null ) {
248+ continue ;
249+ }
238250 VectorSimilarityFunction similarityFunction = fieldInfo .getVectorSimilarityFunction ();
239251 if (leafReader instanceof SegmentReader segmentReader ) {
240252 KnnVectorsReader vectorReader = segmentReader .getVectorReader ();
241253 IVFVectorsReader reader = unwrapReader (vectorReader );
242254 if (reader != null ) {
243255 float [] globalCentroid = reader .getGlobalCentroid (fieldInfo );
244- int numCentroids = reader .getNumCentroids (fieldInfo );
245256
257+ if (similarityFunction == COSINE ) {
258+ VectorUtil .l2normalize (queryVector );
259+ }
246260 // similarity between query vector and global centroid, higher is better
247261 float globalCentroidScore = similarityFunction .compare (queryVector , globalCentroid );
248262 if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
249263 globalCentroidScore = VectorUtil .scaleMaxInnerProductScore (globalCentroidScore );
250264 }
251265
252266 // clusters per vector (< 1), higher is better (better coverage)
267+ int numCentroids = reader .getNumCentroids (fieldInfo );
253268 double centroidDensity = (double ) numCentroids / leafReader .numDocs ();
254269
270+ if (numCentroids > 64 ) {
271+ float [] parentCentroidsScores = reader .getParentCentroidsScores (
272+ fieldInfo ,
273+ numCentroids ,
274+ reader .getIvfCentroids (fieldInfo ),
275+ queryVector
276+ );
277+ Arrays .sort (parentCentroidsScores );
278+ globalCentroidScore = (parentCentroidsScores [0 ] + parentCentroidsScores [1 ] + globalCentroidScore ) / 3 ;
279+ }
280+
255281 double affinityScore = globalCentroidScore * (1 + centroidDensity );
256282
257283 segmentAffinities .add (new SegmentAffinity (context , affinityScore ));
@@ -261,9 +287,6 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
261287 }
262288 }
263289
264- // TODO: sort segments by affinity score in descending order, and cut the long tail ?
265- //segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore()));
266- //...subList(0, (int) (segmentAffinities.size() * 0.99));
267290 return segmentAffinities ;
268291 }
269292
0 commit comments