2929
3030import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .QUERY_BITS ;
3131import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
32- import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
33- import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
3432import static org .elasticsearch .index .codec .vectors .BQSpaceUtils .transposeHalfByte ;
3533import static org .elasticsearch .index .codec .vectors .BQVectorUtils .discretize ;
3634import static org .elasticsearch .index .codec .vectors .OptimizedScalarQuantizer .DEFAULT_LAMBDA ;
4139 * brute force and then scores the top ones using the posting list.
4240 */
4341public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
44- private static final float FOUR_BIT_SCALE = 1f / ((1 << 4 ) - 1 );
42+
43+ // The percentage of centroids that are scored to keep recall
44+ public static final double CENTROID_SAMPLING_PERCENTAGE = 0.1 ;
4545
4646 public DefaultIVFVectorsReader (SegmentReadState state , FlatVectorsReader rawVectorsReader ) throws IOException {
4747 super (state , rawVectorsReader );
@@ -54,8 +54,12 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
5454 final float globalCentroidDp = fieldEntry .globalCentroidDp ();
5555 final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
5656 final int [] scratch = new int [targetQuery .length ];
57+ float [] targetQueryCpoy = ArrayUtil .copyArray (targetQuery );
58+ if (fieldInfo .getVectorSimilarityFunction () == COSINE ) {
59+ VectorUtil .l2normalize (targetQueryCpoy );
60+ }
5761 final OptimizedScalarQuantizer .QuantizationResult queryParams = scalarQuantizer .scalarQuantize (
58- ArrayUtil . copyArray ( targetQuery ) ,
62+ targetQueryCpoy ,
5963 scratch ,
6064 (byte ) 4 ,
6165 fieldEntry .globalCentroid ()
@@ -65,68 +69,265 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
6569 quantized [i ] = (byte ) scratch [i ];
6670 }
6771 final ES91Int4VectorsScorer scorer = ESVectorUtil .getES91Int4VectorsScorer (centroids , fieldInfo .getVectorDimension ());
68- NeighborQueue queue = new NeighborQueue (fieldEntry .numCentroids (), true );
6972 centroids .seek (0L );
73+ int numParents = centroids .readVInt ();
74+ if (numParents > 0 ) {
75+ return getCentroidIteratorWithParents (
76+ fieldInfo ,
77+ centroids ,
78+ numParents ,
79+ numCentroids ,
80+ scorer ,
81+ quantized ,
82+ queryParams ,
83+ globalCentroidDp
84+ );
85+ }
86+ return getCentroidIteratorNoParent (fieldInfo , centroids , numCentroids , scorer , quantized , queryParams , globalCentroidDp );
87+ }
88+
89+ private CentroidIterator getCentroidIteratorNoParent (
90+ FieldInfo fieldInfo ,
91+ IndexInput centroids ,
92+ int numCentroids ,
93+ ES91Int4VectorsScorer scorer ,
94+ byte [] quantizeQuery ,
95+ OptimizedScalarQuantizer .QuantizationResult queryParams ,
96+ float globalCentroidDp
97+ ) throws IOException {
98+ final NeighborQueue neighborQueue = new NeighborQueue (numCentroids , true );
99+ int4QuantizedScoreBulk (
100+ neighborQueue ,
101+ centroids ,
102+ numCentroids ,
103+ 0 ,
104+ scorer ,
105+ quantizeQuery ,
106+ queryParams ,
107+ new float [3 ], // targetCorrections
108+ globalCentroidDp ,
109+ fieldInfo .getVectorSimilarityFunction (),
110+ new float [ES91Int4VectorsScorer .BULK_SIZE ]
111+ );
112+ long offset = centroids .getFilePointer ();
113+ return new CentroidIterator () {
114+ @ Override
115+ public boolean hasNext () {
116+ return neighborQueue .size () > 0 ;
117+ }
118+
119+ @ Override
120+ public long nextPostingListOffset () throws IOException {
121+ int centroidOrdinal = neighborQueue .pop ();
122+ centroids .seek (offset + (long ) Long .BYTES * centroidOrdinal );
123+ return centroids .readLong ();
124+ }
125+ };
126+ }
127+
128+ private CentroidIterator getCentroidIteratorWithParents (
129+ FieldInfo fieldInfo ,
130+ IndexInput centroids ,
131+ int numParents ,
132+ int numCentroids ,
133+ ES91Int4VectorsScorer scorer ,
134+ byte [] quantizeQuery ,
135+ OptimizedScalarQuantizer .QuantizationResult queryParams ,
136+ float globalCentroidDp
137+ ) throws IOException {
138+ final int maxChildrenSize = centroids .readVInt ();
139+ final NeighborQueue parentsQueue = new NeighborQueue (numParents , true );
140+ final float [] scores = new float [ES91Int4VectorsScorer .BULK_SIZE ];
70141 final float [] centroidCorrectiveValues = new float [3 ];
71- for (int i = 0 ; i < numCentroids ; i ++) {
72- final float qcDist = scorer .int4DotProduct (quantized );
73- centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
74- final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
75- float score = int4QuantizedScore (
76- qcDist ,
142+ int4QuantizedScoreBulk (
143+ parentsQueue ,
144+ centroids ,
145+ numParents ,
146+ 0 ,
147+ scorer ,
148+ quantizeQuery ,
149+ queryParams ,
150+ centroidCorrectiveValues , // targetCorrections
151+ globalCentroidDp ,
152+ fieldInfo .getVectorSimilarityFunction (),
153+ scores
154+ );
155+ final int bufferSize = (int ) Math .max (numCentroids * CENTROID_SAMPLING_PERCENTAGE , 1 );
156+ long centroidQuantizeSize = fieldInfo .getVectorDimension () + 3 * Float .BYTES + Short .BYTES ;
157+ long offset = centroids .getFilePointer ();
158+ long childrenOffset = offset + (long ) Long .BYTES * numParents ;
159+ NeighborQueue currentParentQueue = new NeighborQueue (maxChildrenSize , true );
160+ NeighborQueue neighborQueue = new NeighborQueue (bufferSize , true );
161+ while (parentsQueue .size () > 0 && neighborQueue .size () < bufferSize ) {
162+ int pop = parentsQueue .pop ();
163+ populateOneChildrenGroup (
164+ currentParentQueue ,
165+ centroids ,
166+ offset + 2L * Integer .BYTES * pop ,
167+ childrenOffset ,
168+ centroidQuantizeSize ,
169+ fieldInfo ,
170+ scorer ,
171+ quantizeQuery ,
77172 queryParams ,
78- fieldInfo .getVectorDimension (),
79173 centroidCorrectiveValues ,
80- quantizedCentroidComponentSum ,
81174 globalCentroidDp ,
82- fieldInfo . getVectorSimilarityFunction ()
175+ scores
83176 );
84- queue .add (i , score );
177+ while (currentParentQueue .size () > 0 && neighborQueue .size () < bufferSize ) {
178+ float score = currentParentQueue .topScore ();
179+ int children = currentParentQueue .pop ();
180+ neighborQueue .add (children , score );
181+ }
85182 }
86- final long offset = centroids .getFilePointer ();
183+ long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids ;
184+
87185 return new CentroidIterator () {
88186 @ Override
89187 public boolean hasNext () {
90- return queue .size () > 0 ;
188+ return neighborQueue .size () > 0 ;
91189 }
92190
93191 @ Override
94192 public long nextPostingListOffset () throws IOException {
95- int centroidOrdinal = queue .pop ();
96- centroids .seek (offset + (long ) Long .BYTES * centroidOrdinal );
193+ int centroidOrdinal = neighborQueue .pop ();
194+ updateQueue ();
195+ centroids .seek (childrenFileOffsets + (long ) Long .BYTES * centroidOrdinal );
97196 return centroids .readLong ();
98197 }
198+
199+ private void updateQueue () throws IOException {
200+ if (currentParentQueue .size () > 0 ) {
201+ float score = currentParentQueue .topScore ();
202+ int children = currentParentQueue .pop ();
203+ neighborQueue .add (children , score );
204+ } else {
205+ if (parentsQueue .size () > 0 ) {
206+ int pop = parentsQueue .pop ();
207+ populateOneChildrenGroup (
208+ currentParentQueue ,
209+ centroids ,
210+ offset + 2L * Integer .BYTES * pop ,
211+ childrenOffset ,
212+ centroidQuantizeSize ,
213+ fieldInfo ,
214+ scorer ,
215+ quantizeQuery ,
216+ queryParams ,
217+ centroidCorrectiveValues ,
218+ globalCentroidDp ,
219+ scores
220+ );
221+ updateQueue ();
222+ }
223+ }
224+ }
99225 };
100226 }
101227
102- // TODO can we do this in off-heap blocks?
103- private float int4QuantizedScore (
104- float qcDist ,
228+ private void populateOneChildrenGroup (
229+ NeighborQueue neighborQueue ,
230+ IndexInput centroids ,
231+ long parentOffset ,
232+ long childrenOffset ,
233+ long centroidQuantizeSize ,
234+ FieldInfo fieldInfo ,
235+ ES91Int4VectorsScorer scorer ,
236+ byte [] quantizeQuery ,
237+ OptimizedScalarQuantizer .QuantizationResult queryParams ,
238+ float [] targetCorrections ,
239+ float globalCentroidDp ,
240+ float [] scores
241+ ) throws IOException {
242+ centroids .seek (parentOffset );
243+ int childrenOrdinal = centroids .readInt ();
244+ int numChildren = centroids .readInt ();
245+ centroids .seek (childrenOffset + centroidQuantizeSize * childrenOrdinal );
246+ int4QuantizedScoreBulk (
247+ neighborQueue ,
248+ centroids ,
249+ numChildren ,
250+ childrenOrdinal ,
251+ scorer ,
252+ quantizeQuery ,
253+ queryParams ,
254+ targetCorrections ,
255+ globalCentroidDp ,
256+ fieldInfo .getVectorSimilarityFunction (),
257+ scores
258+ );
259+ }
260+
261+ private void int4QuantizedScoreBulk (
262+ NeighborQueue neighborQueue ,
263+ IndexInput centroids ,
264+ int size ,
265+ int scoresOffset ,
266+ ES91Int4VectorsScorer scorer ,
267+ byte [] quantizeQuery ,
105268 OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
106- int dims ,
107269 float [] targetCorrections ,
108- int targetComponentSum ,
109270 float centroidDp ,
110- VectorSimilarityFunction similarityFunction
111- ) {
112- float ax = targetCorrections [0 ];
113- float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
114- float ay = queryCorrections .lowerInterval ();
115- float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
116- float y1 = queryCorrections .quantizedComponentSum ();
117- float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
118- if (similarityFunction == EUCLIDEAN ) {
119- score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
120- return Math .max (1 / (1f + score ), 0 );
121- } else {
122- // For cosine and max inner product, we need to apply the additional correction, which is
123- // assumed to be the non-centered dot-product between the vector and the centroid
124- score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
125- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
126- return VectorUtil .scaleMaxInnerProductScore (score );
271+ VectorSimilarityFunction similarityFunction ,
272+ float [] scores
273+ ) throws IOException {
274+ int limit = size - ES91Int4VectorsScorer .BULK_SIZE + 1 ;
275+ int i = 0 ;
276+ for (; i < limit ; i += ES91Int4VectorsScorer .BULK_SIZE ) {
277+ scorer .scoreBulk (
278+ quantizeQuery ,
279+ queryCorrections .lowerInterval (),
280+ queryCorrections .upperInterval (),
281+ queryCorrections .quantizedComponentSum (),
282+ queryCorrections .additionalCorrection (),
283+ similarityFunction ,
284+ centroidDp ,
285+ scores
286+ );
287+ for (int j = 0 ; j < ES91Int4VectorsScorer .BULK_SIZE ; j ++) {
288+ neighborQueue .add (scoresOffset + i + j , scores [j ]);
127289 }
128- return Math .max ((1f + score ) / 2f , 0 );
129290 }
291+
292+ for (; i < size ; i ++) {
293+ float score = int4QuantizedScore (
294+ centroids ,
295+ scorer ,
296+ quantizeQuery ,
297+ queryCorrections ,
298+ targetCorrections ,
299+ centroidDp ,
300+ similarityFunction
301+ );
302+ neighborQueue .add (scoresOffset + i , score );
303+ }
304+ }
305+
306+ private float int4QuantizedScore (
307+ IndexInput centroids ,
308+ ES91Int4VectorsScorer scorer ,
309+ byte [] quantizeQuery ,
310+ OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
311+ float [] targetCorrections ,
312+ float centroidDp ,
313+ VectorSimilarityFunction similarityFunction
314+ ) throws IOException {
315+ float qcDist = scorer .int4DotProduct (quantizeQuery );
316+ centroids .readFloats (targetCorrections , 0 , 3 );
317+ final int targetComponentSum = Short .toUnsignedInt (centroids .readShort ());
318+ return scorer .applyCorrections (
319+ queryCorrections .lowerInterval (),
320+ queryCorrections .upperInterval (),
321+ queryCorrections .quantizedComponentSum (),
322+ queryCorrections .additionalCorrection (),
323+ similarityFunction ,
324+ centroidDp ,
325+ targetCorrections [0 ],
326+ targetCorrections [1 ],
327+ targetComponentSum ,
328+ targetCorrections [2 ],
329+ qcDist
330+ );
130331 }
131332
132333 @ Override
0 commit comments