@@ -46,6 +46,42 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4646 super (state , rawVectorsReader );
4747 }
4848
49+ private abstract static class BaseCentroidQueryScorer implements CentroidQueryScorer {
50+
51+ // TODO can we do this in off-heap blocks?
52+ float int4QuantizedScore (
53+ float qcDist ,
54+ OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
55+ int dims ,
56+ float [] targetCorrections ,
57+ int targetComponentSum ,
58+ float centroidDp ,
59+ VectorSimilarityFunction similarityFunction
60+ ) {
61+ float ax = targetCorrections [0 ];
62+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
63+ float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
64+ float ay = queryCorrections .lowerInterval ();
65+ float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
66+ float y1 = queryCorrections .quantizedComponentSum ();
67+ float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
68+ if (similarityFunction == EUCLIDEAN ) {
69+ score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
70+ return Math .max (1 / (1f + score ), 0 );
71+ } else {
72+ // For cosine and max inner product, we need to apply the additional correction, which is
73+ // assumed to be the non-centered dot-product between the vector and the centroid
74+ score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
75+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
76+ return VectorUtil .scaleMaxInnerProductScore (score );
77+ }
78+ return Math .max ((1f + score ) / 2f , 0 );
79+ }
80+ }
81+ }
82+
83+ private abstract static class ParentCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWChildrenQueryScorer {}
84+
4985 @ Override
5086 CentroidQueryScorer getCentroidScorer (
5187 FieldInfo fieldInfo ,
@@ -65,7 +101,7 @@ CentroidQueryScorer getCentroidScorer(
65101 fieldEntry .globalCentroid ()
66102 );
67103 final ES91Int4VectorsScorer scorer = ESVectorUtil .getES91Int4VectorsScorer (centroids , fieldInfo .getVectorDimension ());
68- return new CentroidQueryScorer () {
104+ return new BaseCentroidQueryScorer () {
69105 int currentCentroid = -1 ;
70106 private final float [] centroid = new float [fieldInfo .getVectorDimension ()];
71107 private final float [] centroidCorrectiveValues = new float [3 ];
@@ -90,6 +126,7 @@ public float[] centroid(int centroidOrdinal) throws IOException {
90126 return centroid ;
91127 }
92128
129+ @ Override
93130 public void bulkScore (NeighborQueue queue ) throws IOException {
94131 // TODO: bulk score centroids like we do with posting lists
95132 centroids .seek (quantizedCentroidsOffset );
@@ -121,44 +158,16 @@ private float score() throws IOException {
121158 fieldInfo .getVectorSimilarityFunction ()
122159 );
123160 }
124-
125- // TODO can we do this in off-heap blocks?
126- private float int4QuantizedScore (
127- float qcDist ,
128- OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
129- int dims ,
130- float [] targetCorrections ,
131- int targetComponentSum ,
132- float centroidDp ,
133- VectorSimilarityFunction similarityFunction
134- ) {
135- float ax = targetCorrections [0 ];
136- // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
137- float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
138- float ay = queryCorrections .lowerInterval ();
139- float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
140- float y1 = queryCorrections .quantizedComponentSum ();
141- float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
142- if (similarityFunction == EUCLIDEAN ) {
143- score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
144- return Math .max (1 / (1f + score ), 0 );
145- } else {
146- // For cosine and max inner product, we need to apply the additional correction, which is
147- // assumed to be the non-centered dot-product between the vector and the centroid
148- score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
149- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
150- return VectorUtil .scaleMaxInnerProductScore (score );
151- }
152- return Math .max ((1f + score ) / 2f , 0 );
153- }
154- }
155161 };
156162 }
157163
158- // FIXME: clean up duplicative code between the scorers
159164 @ Override
160- ParentCentroidQueryScorer getParentCentroidScorer (FieldInfo fieldInfo , int numCentroids , IndexInput centroids , float [] targetQuery )
161- throws IOException {
165+ ParentCentroidQueryScorer getParentCentroidScorer (
166+ FieldInfo fieldInfo ,
167+ int numParentCentroids ,
168+ IndexInput centroids ,
169+ float [] targetQuery
170+ ) throws IOException {
162171 FieldEntry fieldEntry = fields .get (fieldInfo .number );
163172 float [] globalCentroid = fieldEntry .globalCentroid ();
164173 float globalCentroidDp = fieldEntry .globalCentroidDp ();
@@ -183,15 +192,15 @@ ParentCentroidQueryScorer getParentCentroidScorer(FieldInfo fieldInfo, int numCe
183192
184193 @ Override
185194 public int size () {
186- return numCentroids ;
195+ return numParentCentroids ;
187196 }
188197
189198 @ Override
190199 public float [] centroid (int centroidOrdinal ) throws IOException {
191200 throw new UnsupportedOperationException ("can't score at the parent level" );
192201 }
193202
194- private void readQuantizedAndRawCentroid (int centroidOrdinal ) throws IOException {
203+ private void readChildDetails (int centroidOrdinal ) throws IOException {
195204 if (centroidOrdinal == currentCentroid ) {
196205 return ;
197206 }
@@ -201,28 +210,29 @@ private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException
201210 currentCentroid = centroidOrdinal ;
202211 }
203212
213+ @ Override
204214 public int getChildCentroidStart (int centroidOrdinal ) throws IOException {
205- readQuantizedAndRawCentroid (centroidOrdinal );
215+ readChildDetails (centroidOrdinal );
206216 return childCentroidStart ;
207217 }
208218
219+ @ Override
209220 public int getChildCount (int centroidOrdinal ) throws IOException {
210- readQuantizedAndRawCentroid (centroidOrdinal );
221+ readChildDetails (centroidOrdinal );
211222 return childCount ;
212223 }
213224
214225 @ Override
215226 public void bulkScore (NeighborQueue queue ) throws IOException {
216227 // TODO: bulk score centroids like we do with posting lists
217228 centroids .seek (0L );
218- for (int i = 0 ; i < numCentroids ; i ++) {
229+ for (int i = 0 ; i < numParentCentroids ; i ++) {
219230 queue .add (i , score ());
220231 }
221232 }
222233
223234 @ Override
224235 public void bulkScore (NeighborQueue queue , int start , int end ) throws IOException {
225- // FIXME: this never gets used ... I wonder if we just need an entirely different interface for this
226236 // TODO: bulk score centroids like we do with posting lists
227237 centroids .seek (parentNodeByteSize * start );
228238 for (int i = start ; i < end ; i ++) {
@@ -235,10 +245,10 @@ private float score() throws IOException {
235245 centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
236246 final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
237247
238- // FIXME: move these now? to a different place in the file?
248+ // TODO: should we consider a different format such as moving these to the beginning of the file to benefit bulk read
239249 // TODO: cache these at this point when scoring since we'll likely read many of them?
240- centroids . readInt (); // child partition start
241- centroids .readInt (); // child partition count
250+ // child partition start, child partition count
251+ centroids .skipBytes ( Integer . BYTES * 2 );
242252
243253 return int4QuantizedScore (
244254 qcDist ,
@@ -250,46 +260,27 @@ private float score() throws IOException {
250260 fieldInfo .getVectorSimilarityFunction ()
251261 );
252262 }
253-
254- // TODO can we do this in off-heap blocks?
255- private float int4QuantizedScore (
256- float qcDist ,
257- OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
258- int dims ,
259- float [] targetCorrections ,
260- int targetComponentSum ,
261- float centroidDp ,
262- VectorSimilarityFunction similarityFunction
263- ) {
264- float ax = targetCorrections [0 ];
265- // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
266- float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
267- float ay = queryCorrections .lowerInterval ();
268- float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
269- float y1 = queryCorrections .quantizedComponentSum ();
270- float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
271- if (similarityFunction == EUCLIDEAN ) {
272- score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
273- return Math .max (1 / (1f + score ), 0 );
274- } else {
275- // For cosine and max inner product, we need to apply the additional correction, which is
276- // assumed to be the non-centered dot-product between the vector and the centroid
277- score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
278- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
279- return VectorUtil .scaleMaxInnerProductScore (score );
280- }
281- return Math .max ((1f + score ) / 2f , 0 );
282- }
283- }
284263 };
285264 }
286265
266+ @ Override
267+ NeighborQueue scorePostingLists (
268+ FieldInfo fieldInfo ,
269+ KnnCollector knnCollector ,
270+ CentroidQueryScorer centroidQueryScorer ,
271+ int nProbe ,
272+ int start ,
273+ int count
274+ ) throws IOException {
275+ NeighborQueue neighborQueue = new NeighborQueue (count , true );
276+ centroidQueryScorer .bulkScore (neighborQueue , start , start + count );
277+ return neighborQueue ;
278+ }
279+
287280 @ Override
288281 NeighborQueue scorePostingLists (FieldInfo fieldInfo , KnnCollector knnCollector , CentroidQueryScorer centroidQueryScorer , int nProbe )
289282 throws IOException {
290- NeighborQueue neighborQueue = new NeighborQueue (centroidQueryScorer .size (), true );
291- centroidQueryScorer .bulkScore (neighborQueue );
292- return neighborQueue ;
283+ return scorePostingLists (fieldInfo , knnCollector , centroidQueryScorer , nProbe , 0 , centroidQueryScorer .size ());
293284 }
294285
295286 @ Override
0 commit comments