2929
3030import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .QUERY_BITS ;
3131import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
32- import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
33- import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
3432import static org .elasticsearch .index .codec .vectors .BQSpaceUtils .transposeHalfByte ;
3533import static org .elasticsearch .index .codec .vectors .BQVectorUtils .discretize ;
3634import static org .elasticsearch .index .codec .vectors .OptimizedScalarQuantizer .DEFAULT_LAMBDA ;
4139 * brute force and then scores the top ones using the posting list.
4240 */
4341public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
44- private static final float FOUR_BIT_SCALE = 1f / ((1 << 4 ) - 1 );
42+
43+ // The percentage of centroids that are scored to keep recall
44+ public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2 ;
4545
4646 public DefaultIVFVectorsReader (SegmentReadState state , FlatVectorsReader rawVectorsReader ) throws IOException {
4747 super (state , rawVectorsReader );
@@ -54,8 +54,12 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
5454 final float globalCentroidDp = fieldEntry .globalCentroidDp ();
5555 final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
5656 final int [] scratch = new int [targetQuery .length ];
57+ float [] targetQueryCopy = ArrayUtil .copyArray (targetQuery );
58+ if (fieldInfo .getVectorSimilarityFunction () == COSINE ) {
59+ VectorUtil .l2normalize (targetQueryCopy );
60+ }
5761 final OptimizedScalarQuantizer .QuantizationResult queryParams = scalarQuantizer .scalarQuantize (
58- ArrayUtil . copyArray ( targetQuery ) ,
62+ targetQueryCopy ,
5963 scratch ,
6064 (byte ) 4 ,
6165 fieldEntry .globalCentroid ()
@@ -65,67 +69,227 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
6569 quantized [i ] = (byte ) scratch [i ];
6670 }
6771 final ES91Int4VectorsScorer scorer = ESVectorUtil .getES91Int4VectorsScorer (centroids , fieldInfo .getVectorDimension ());
68- NeighborQueue queue = new NeighborQueue (fieldEntry .numCentroids (), true );
6972 centroids .seek (0L );
70- final float [] centroidCorrectiveValues = new float [3 ];
71- for (int i = 0 ; i < numCentroids ; i ++) {
72- final float qcDist = scorer .int4DotProduct (quantized );
73- centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
74- final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
75- float score = int4QuantizedScore (
76- qcDist ,
73+ int numParents = centroids .readVInt ();
74+ if (numParents > 0 ) {
75+ return getCentroidIteratorWithParents (
76+ fieldInfo ,
77+ centroids ,
78+ numParents ,
79+ numCentroids ,
80+ scorer ,
81+ quantized ,
7782 queryParams ,
78- fieldInfo .getVectorDimension (),
79- centroidCorrectiveValues ,
80- quantizedCentroidComponentSum ,
81- globalCentroidDp ,
82- fieldInfo .getVectorSimilarityFunction ()
83+ globalCentroidDp
8384 );
84- queue .add (i , score );
8585 }
86- final long offset = centroids .getFilePointer ();
86+ return getCentroidIteratorNoParent (fieldInfo , centroids , numCentroids , scorer , quantized , queryParams , globalCentroidDp );
87+ }
88+
89+ private static CentroidIterator getCentroidIteratorNoParent (
90+ FieldInfo fieldInfo ,
91+ IndexInput centroids ,
92+ int numCentroids ,
93+ ES91Int4VectorsScorer scorer ,
94+ byte [] quantizeQuery ,
95+ OptimizedScalarQuantizer .QuantizationResult queryParams ,
96+ float globalCentroidDp
97+ ) throws IOException {
98+ final NeighborQueue neighborQueue = new NeighborQueue (numCentroids , true );
99+ score (
100+ neighborQueue ,
101+ numCentroids ,
102+ 0 ,
103+ scorer ,
104+ quantizeQuery ,
105+ queryParams ,
106+ globalCentroidDp ,
107+ fieldInfo .getVectorSimilarityFunction (),
108+ new float [ES91Int4VectorsScorer .BULK_SIZE ]
109+ );
110+ long offset = centroids .getFilePointer ();
87111 return new CentroidIterator () {
88112 @ Override
89113 public boolean hasNext () {
90- return queue .size () > 0 ;
114+ return neighborQueue .size () > 0 ;
91115 }
92116
93117 @ Override
94118 public long nextPostingListOffset () throws IOException {
95- int centroidOrdinal = queue .pop ();
119+ int centroidOrdinal = neighborQueue .pop ();
96120 centroids .seek (offset + (long ) Long .BYTES * centroidOrdinal );
97121 return centroids .readLong ();
98122 }
99123 };
100124 }
101125
102- // TODO can we do this in off-heap blocks?
103- private float int4QuantizedScore (
104- float qcDist ,
126+ private static CentroidIterator getCentroidIteratorWithParents (
127+ FieldInfo fieldInfo ,
128+ IndexInput centroids ,
129+ int numParents ,
130+ int numCentroids ,
131+ ES91Int4VectorsScorer scorer ,
132+ byte [] quantizeQuery ,
133+ OptimizedScalarQuantizer .QuantizationResult queryParams ,
134+ float globalCentroidDp
135+ ) throws IOException {
136+ // build the three queues we are going to use
137+ final NeighborQueue parentsQueue = new NeighborQueue (numParents , true );
138+ final int maxChildrenSize = centroids .readVInt ();
139+ final NeighborQueue currentParentQueue = new NeighborQueue (maxChildrenSize , true );
140+ final int bufferSize = (int ) Math .max (numCentroids * CENTROID_SAMPLING_PERCENTAGE , 1 );
141+ final NeighborQueue neighborQueue = new NeighborQueue (bufferSize , true );
142+ // score the parents
143+ final float [] scores = new float [ES91Int4VectorsScorer .BULK_SIZE ];
144+ score (
145+ parentsQueue ,
146+ numParents ,
147+ 0 ,
148+ scorer ,
149+ quantizeQuery ,
150+ queryParams ,
151+ globalCentroidDp ,
152+ fieldInfo .getVectorSimilarityFunction (),
153+ scores
154+ );
155+ final long centroidQuantizeSize = fieldInfo .getVectorDimension () + 3 * Float .BYTES + Short .BYTES ;
156+ final long offset = centroids .getFilePointer ();
157+ final long childrenOffset = offset + (long ) Long .BYTES * numParents ;
158+ // populate the children's queue by reading parents one by one
159+ while (parentsQueue .size () > 0 && neighborQueue .size () < bufferSize ) {
160+ final int pop = parentsQueue .pop ();
161+ populateOneChildrenGroup (
162+ currentParentQueue ,
163+ centroids ,
164+ offset + 2L * Integer .BYTES * pop ,
165+ childrenOffset ,
166+ centroidQuantizeSize ,
167+ fieldInfo ,
168+ scorer ,
169+ quantizeQuery ,
170+ queryParams ,
171+ globalCentroidDp ,
172+ scores
173+ );
174+ while (currentParentQueue .size () > 0 && neighborQueue .size () < bufferSize ) {
175+ final float score = currentParentQueue .topScore ();
176+ final int children = currentParentQueue .pop ();
177+ neighborQueue .add (children , score );
178+ }
179+ }
180+ final long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids ;
181+ return new CentroidIterator () {
182+ @ Override
183+ public boolean hasNext () {
184+ return neighborQueue .size () > 0 ;
185+ }
186+
187+ @ Override
188+ public long nextPostingListOffset () throws IOException {
189+ int centroidOrdinal = neighborQueue .pop ();
190+ updateQueue (); // add one children if available so the queue remains fully populated
191+ centroids .seek (childrenFileOffsets + (long ) Long .BYTES * centroidOrdinal );
192+ return centroids .readLong ();
193+ }
194+
195+ private void updateQueue () throws IOException {
196+ if (currentParentQueue .size () > 0 ) {
197+ // add a children from the current parent queue
198+ float score = currentParentQueue .topScore ();
199+ int children = currentParentQueue .pop ();
200+ neighborQueue .add (children , score );
201+ } else if (parentsQueue .size () > 0 ) {
202+ // add a new parent from the current parent queue
203+ int pop = parentsQueue .pop ();
204+ populateOneChildrenGroup (
205+ currentParentQueue ,
206+ centroids ,
207+ offset + 2L * Integer .BYTES * pop ,
208+ childrenOffset ,
209+ centroidQuantizeSize ,
210+ fieldInfo ,
211+ scorer ,
212+ quantizeQuery ,
213+ queryParams ,
214+ globalCentroidDp ,
215+ scores
216+ );
217+ updateQueue ();
218+ }
219+ }
220+ };
221+ }
222+
223+ private static void populateOneChildrenGroup (
224+ NeighborQueue neighborQueue ,
225+ IndexInput centroids ,
226+ long parentOffset ,
227+ long childrenOffset ,
228+ long centroidQuantizeSize ,
229+ FieldInfo fieldInfo ,
230+ ES91Int4VectorsScorer scorer ,
231+ byte [] quantizeQuery ,
232+ OptimizedScalarQuantizer .QuantizationResult queryParams ,
233+ float globalCentroidDp ,
234+ float [] scores
235+ ) throws IOException {
236+ centroids .seek (parentOffset );
237+ int childrenOrdinal = centroids .readInt ();
238+ int numChildren = centroids .readInt ();
239+ centroids .seek (childrenOffset + centroidQuantizeSize * childrenOrdinal );
240+ score (
241+ neighborQueue ,
242+ numChildren ,
243+ childrenOrdinal ,
244+ scorer ,
245+ quantizeQuery ,
246+ queryParams ,
247+ globalCentroidDp ,
248+ fieldInfo .getVectorSimilarityFunction (),
249+ scores
250+ );
251+ }
252+
253+ private static void score (
254+ NeighborQueue neighborQueue ,
255+ int size ,
256+ int scoresOffset ,
257+ ES91Int4VectorsScorer scorer ,
258+ byte [] quantizeQuery ,
105259 OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
106- int dims ,
107- float [] targetCorrections ,
108- int targetComponentSum ,
109260 float centroidDp ,
110- VectorSimilarityFunction similarityFunction
111- ) {
112- float ax = targetCorrections [0 ];
113- float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
114- float ay = queryCorrections .lowerInterval ();
115- float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
116- float y1 = queryCorrections .quantizedComponentSum ();
117- float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
118- if (similarityFunction == EUCLIDEAN ) {
119- score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
120- return Math .max (1 / (1f + score ), 0 );
121- } else {
122- // For cosine and max inner product, we need to apply the additional correction, which is
123- // assumed to be the non-centered dot-product between the vector and the centroid
124- score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
125- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
126- return VectorUtil .scaleMaxInnerProductScore (score );
261+ VectorSimilarityFunction similarityFunction ,
262+ float [] scores
263+ ) throws IOException {
264+ int limit = size - ES91Int4VectorsScorer .BULK_SIZE + 1 ;
265+ int i = 0 ;
266+ for (; i < limit ; i += ES91Int4VectorsScorer .BULK_SIZE ) {
267+ scorer .scoreBulk (
268+ quantizeQuery ,
269+ queryCorrections .lowerInterval (),
270+ queryCorrections .upperInterval (),
271+ queryCorrections .quantizedComponentSum (),
272+ queryCorrections .additionalCorrection (),
273+ similarityFunction ,
274+ centroidDp ,
275+ scores
276+ );
277+ for (int j = 0 ; j < ES91Int4VectorsScorer .BULK_SIZE ; j ++) {
278+ neighborQueue .add (scoresOffset + i + j , scores [j ]);
127279 }
128- return Math .max ((1f + score ) / 2f , 0 );
280+ }
281+
282+ for (; i < size ; i ++) {
283+ float score = scorer .score (
284+ quantizeQuery ,
285+ queryCorrections .lowerInterval (),
286+ queryCorrections .upperInterval (),
287+ queryCorrections .quantizedComponentSum (),
288+ queryCorrections .additionalCorrection (),
289+ similarityFunction ,
290+ centroidDp
291+ );
292+ neighborQueue .add (scoresOffset + i , score );
129293 }
130294 }
131295
0 commit comments