@@ -48,7 +48,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4848 }
4949
5050 @ Override
51- CentroidQueryScorer getCentroidScorer (FieldInfo fieldInfo , int numCentroids , IndexInput centroids , float [] targetQuery )
51+ CentroidIterator getCentroidIterator (FieldInfo fieldInfo , int numCentroids , IndexInput centroids , float [] targetQuery )
5252 throws IOException {
5353 final FieldEntry fieldEntry = fields .get (fieldInfo .number );
5454 final float globalCentroidDp = fieldEntry .globalCentroidDp ();
@@ -65,90 +65,68 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
6565 quantized [i ] = (byte ) scratch [i ];
6666 }
6767 final ES91Int4VectorsScorer scorer = ESVectorUtil .getES91Int4VectorsScorer (centroids , fieldInfo .getVectorDimension ());
68- return new CentroidQueryScorer () {
69- int currentCentroid = -1 ;
70- long postingListOffset ;
71- private final float [] centroidCorrectiveValues = new float [3 ];
72- private final long quantizeCentroidsLength = (long ) numCentroids * (fieldInfo .getVectorDimension () + 3 * Float .BYTES
73- + Short .BYTES );
74-
68+ NeighborQueue queue = new NeighborQueue (fieldEntry .numCentroids (), true );
69+ centroids .seek (0L );
70+ final float [] centroidCorrectiveValues = new float [3 ];
71+ for (int i = 0 ; i < numCentroids ; i ++) {
72+ final float qcDist = scorer .int4DotProduct (quantized );
73+ centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
74+ final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
75+ float score = int4QuantizedScore (
76+ qcDist ,
77+ queryParams ,
78+ fieldInfo .getVectorDimension (),
79+ centroidCorrectiveValues ,
80+ quantizedCentroidComponentSum ,
81+ globalCentroidDp ,
82+ fieldInfo .getVectorSimilarityFunction ()
83+ );
84+ queue .add (i , score );
85+ }
86+ final long offset = centroids .getFilePointer ();
87+ return new CentroidIterator () {
7588 @ Override
76- public int size () {
77- return numCentroids ;
89+ public boolean hasNext () {
90+ return queue . size () > 0 ;
7891 }
7992
8093 @ Override
81- public long postingListOffset (int centroidOrdinal ) throws IOException {
82- if (centroidOrdinal != currentCentroid ) {
83- centroids .seek (quantizeCentroidsLength + (long ) Long .BYTES * centroidOrdinal );
84- postingListOffset = centroids .readLong ();
85- currentCentroid = centroidOrdinal ;
86- }
87- return postingListOffset ;
88- }
89-
90- public void bulkScore (NeighborQueue queue ) throws IOException {
91- // TODO: bulk score centroids like we do with posting lists
92- centroids .seek (0L );
93- for (int i = 0 ; i < numCentroids ; i ++) {
94- queue .add (i , score ());
95- }
96- }
97-
98- private float score () throws IOException {
99- final float qcDist = scorer .int4DotProduct (quantized );
100- centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
101- final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
102- return int4QuantizedScore (
103- qcDist ,
104- queryParams ,
105- fieldInfo .getVectorDimension (),
106- centroidCorrectiveValues ,
107- quantizedCentroidComponentSum ,
108- globalCentroidDp ,
109- fieldInfo .getVectorSimilarityFunction ()
110- );
111- }
112-
113- // TODO can we do this in off-heap blocks?
114- private float int4QuantizedScore (
115- float qcDist ,
116- OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
117- int dims ,
118- float [] targetCorrections ,
119- int targetComponentSum ,
120- float centroidDp ,
121- VectorSimilarityFunction similarityFunction
122- ) {
123- float ax = targetCorrections [0 ];
124- // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
125- float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
126- float ay = queryCorrections .lowerInterval ();
127- float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
128- float y1 = queryCorrections .quantizedComponentSum ();
129- float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
130- if (similarityFunction == EUCLIDEAN ) {
131- score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
132- return Math .max (1 / (1f + score ), 0 );
133- } else {
134- // For cosine and max inner product, we need to apply the additional correction, which is
135- // assumed to be the non-centered dot-product between the vector and the centroid
136- score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
137- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
138- return VectorUtil .scaleMaxInnerProductScore (score );
139- }
140- return Math .max ((1f + score ) / 2f , 0 );
141- }
94+ public long nextPostingListOffset () throws IOException {
95+ int centroidOrdinal = queue .pop ();
96+ centroids .seek (offset + (long ) Long .BYTES * centroidOrdinal );
97+ return centroids .readLong ();
14298 }
14399 };
144100 }
145101
146- @ Override
147- NeighborQueue scorePostingLists (FieldInfo fieldInfo , KnnCollector knnCollector , CentroidQueryScorer centroidQueryScorer , int nProbe )
148- throws IOException {
149- NeighborQueue neighborQueue = new NeighborQueue (centroidQueryScorer .size (), true );
150- centroidQueryScorer .bulkScore (neighborQueue );
151- return neighborQueue ;
102+ // Turns the raw int4 dot product (qcDist) of one quantized centroid and the quantized query into a similarity score using both sides' correction terms. TODO can we do this in off-heap blocks?
103+ private float int4QuantizedScore (
104+ float qcDist ,
105+ OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
106+ int dims ,
107+ float [] targetCorrections ,
108+ int targetComponentSum ,
109+ float centroidDp ,
110+ VectorSimilarityFunction similarityFunction
111+ ) {
112+ float ax = targetCorrections [0 ]; // presumably the target's lower interval bound (used like `ay` below) - confirm against the writer
113+ float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ; // target interval width, scaled by FOUR_BIT_SCALE
114+ float ay = queryCorrections .lowerInterval ();
115+ float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ; // query interval width, scaled the same way as lx
116+ float y1 = queryCorrections .quantizedComponentSum ();
117+ float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ; // matches the expansion of sum_i (ax + lx*x_i) * (ay + ly*y_i), assuming qcDist = sum_i x_i*y_i
118+ if (similarityFunction == EUCLIDEAN ) {
119+ score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ; // both additional corrections minus twice the estimated dot product
120+ return Math .max (1 / (1f + score ), 0 ); // map to 1 / (1 + value), never below 0
121+ } else {
122+ // For every similarity other than EUCLIDEAN (per the original note: cosine and max inner product), apply the
123+ // additional corrections and subtract centroidDp, assumed to be the non-centered dot-product between the vector and the centroid
124+ score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
125+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
126+ return VectorUtil .scaleMaxInnerProductScore (score );
127+ }
128+ return Math .max ((1f + score ) / 2f , 0 ); // remaining similarities: (1 + score) / 2, never below 0
129+ }
152130 }
153131
154132 @ Override
0 commit comments