88 */
99package org .elasticsearch .simdvec ;
1010
11+ import org .apache .lucene .index .VectorSimilarityFunction ;
1112import org .apache .lucene .store .IndexInput ;
13+ import org .apache .lucene .util .VectorUtil ;
1214
1315import java .io .IOException ;
1416
17+ import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
18+ import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
19+
1520/** Scorer for quantized vectors stored as an {@link IndexInput}.
1621 * <p>
1722 * Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
2025 * */
2126public class ES91Int4VectorsScorer {
2227
28+ public static final int BULK_SIZE = 16 ;
29+ protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4 ) - 1 );
30+
2331 /** The wrapper {@link IndexInput}. */
2432 protected final IndexInput in ;
2533 protected final int dimensions ;
2634 protected byte [] scratch ;
2735
36+ protected final float [] lowerIntervals = new float [BULK_SIZE ];
37+ protected final float [] upperIntervals = new float [BULK_SIZE ];
38+ protected final int [] targetComponentSums = new int [BULK_SIZE ];
39+ protected final float [] additionalCorrections = new float [BULK_SIZE ];
40+
2841 /** Sole constructor, called by sub-classes. */
2942 public ES91Int4VectorsScorer (IndexInput in , int dimensions ) {
3043 this .in = in ;
3144 this .dimensions = dimensions ;
3245 scratch = new byte [dimensions ];
3346 }
3447
48+ /**
49+ * compute the quantize distance between the provided quantized query and the quantized vector
50+ * that is read from the wrapped {@link IndexInput}.
51+ */
3552 public long int4DotProduct (byte [] b ) throws IOException {
3653 in .readBytes (scratch , 0 , dimensions );
3754 int total = 0 ;
@@ -40,4 +57,129 @@ public long int4DotProduct(byte[] b) throws IOException {
4057 }
4158 return total ;
4259 }
60+
61+ /**
62+ * compute the quantize distance between the provided quantized query and the quantized vectors
63+ * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
64+ * determined by {code count} and the results are stored in the provided {@code scores} array.
65+ */
66+ public void int4DotProductBulk (byte [] b , int count , float [] scores ) throws IOException {
67+ for (int i = 0 ; i < count ; i ++) {
68+ scores [i ] = int4DotProduct (b );
69+ }
70+ }
71+
72+ /**
73+ * Computes the score by applying the necessary corrections to the provided quantized distance.
74+ */
75+ public float score (
76+ byte [] q ,
77+ float queryLowerInterval ,
78+ float queryUpperInterval ,
79+ int queryComponentSum ,
80+ float queryAdditionalCorrection ,
81+ VectorSimilarityFunction similarityFunction ,
82+ float centroidDp
83+ ) throws IOException {
84+ float score = int4DotProduct (q );
85+ in .readFloats (lowerIntervals , 0 , 3 );
86+ int addition = Short .toUnsignedInt (in .readShort ());
87+ return applyCorrections (
88+ queryLowerInterval ,
89+ queryUpperInterval ,
90+ queryComponentSum ,
91+ queryAdditionalCorrection ,
92+ similarityFunction ,
93+ centroidDp ,
94+ lowerIntervals [0 ],
95+ lowerIntervals [1 ],
96+ addition ,
97+ lowerIntervals [2 ],
98+ score
99+ );
100+ }
101+
102+ /**
103+ * compute the distance between the provided quantized query and the quantized vectors that are
104+ * read from the wrapped {@link IndexInput}.
105+ *
106+ * <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
107+ * input is as follows: First the quantized vectors are read from the input,then all the lower
108+ * intervals as floats, then all the upper intervals as floats, then all the target component sums
109+ * as shorts, and finally all the additional corrections as floats.
110+ *
111+ * <p>The results are stored in the provided scores array.
112+ */
113+ public void scoreBulk (
114+ byte [] q ,
115+ float queryLowerInterval ,
116+ float queryUpperInterval ,
117+ int queryComponentSum ,
118+ float queryAdditionalCorrection ,
119+ VectorSimilarityFunction similarityFunction ,
120+ float centroidDp ,
121+ float [] scores
122+ ) throws IOException {
123+ int4DotProductBulk (q , BULK_SIZE , scores );
124+ in .readFloats (lowerIntervals , 0 , BULK_SIZE );
125+ in .readFloats (upperIntervals , 0 , BULK_SIZE );
126+ for (int i = 0 ; i < BULK_SIZE ; i ++) {
127+ targetComponentSums [i ] = Short .toUnsignedInt (in .readShort ());
128+ }
129+ in .readFloats (additionalCorrections , 0 , BULK_SIZE );
130+ for (int i = 0 ; i < BULK_SIZE ; i ++) {
131+ scores [i ] = applyCorrections (
132+ queryLowerInterval ,
133+ queryUpperInterval ,
134+ queryComponentSum ,
135+ queryAdditionalCorrection ,
136+ similarityFunction ,
137+ centroidDp ,
138+ lowerIntervals [i ],
139+ upperIntervals [i ],
140+ targetComponentSums [i ],
141+ additionalCorrections [i ],
142+ scores [i ]
143+ );
144+ }
145+ }
146+
147+ /**
148+ * Computes the score by applying the necessary corrections to the provided quantized distance.
149+ */
150+ public float applyCorrections (
151+ float queryLowerInterval ,
152+ float queryUpperInterval ,
153+ int queryComponentSum ,
154+ float queryAdditionalCorrection ,
155+ VectorSimilarityFunction similarityFunction ,
156+ float centroidDp ,
157+ float lowerInterval ,
158+ float upperInterval ,
159+ int targetComponentSum ,
160+ float additionalCorrection ,
161+ float qcDist
162+ ) {
163+ float ax = lowerInterval ;
164+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
165+ float lx = upperInterval - ax ;
166+ float ay = queryLowerInterval ;
167+ float ly = (queryUpperInterval - ay ) * FOUR_BIT_SCALE ;
168+ float y1 = queryComponentSum ;
169+ float score = ax * ay * dimensions + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
170+ // For euclidean, we need to invert the score and apply the additional correction, which is
171+ // assumed to be the squared l2norm of the centroid centered vectors.
172+ if (similarityFunction == EUCLIDEAN ) {
173+ score = queryAdditionalCorrection + additionalCorrection - 2 * score ;
174+ return Math .max (1 / (1f + score ), 0 );
175+ } else {
176+ // For cosine and max inner product, we need to apply the additional correction, which is
177+ // assumed to be the non-centered dot-product between the vector and the centroid
178+ score += queryAdditionalCorrection + additionalCorrection - centroidDp ;
179+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
180+ return VectorUtil .scaleMaxInnerProductScore (score );
181+ }
182+ return Math .max ((1f + score ) / 2f , 0 );
183+ }
184+ }
43185}
0 commit comments