@@ -48,7 +48,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
48
48
}
49
49
50
50
@ Override
51
- CentroidQueryScorer getCentroidScorer (FieldInfo fieldInfo , int numCentroids , IndexInput centroids , float [] targetQuery )
51
+ CentroidIterator getCentroidIterator (FieldInfo fieldInfo , int numCentroids , IndexInput centroids , float [] targetQuery )
52
52
throws IOException {
53
53
final FieldEntry fieldEntry = fields .get (fieldInfo .number );
54
54
final float globalCentroidDp = fieldEntry .globalCentroidDp ();
@@ -65,90 +65,68 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
65
65
quantized [i ] = (byte ) scratch [i ];
66
66
}
67
67
final ES91Int4VectorsScorer scorer = ESVectorUtil .getES91Int4VectorsScorer (centroids , fieldInfo .getVectorDimension ());
68
- return new CentroidQueryScorer () {
69
- int currentCentroid = -1 ;
70
- long postingListOffset ;
71
- private final float [] centroidCorrectiveValues = new float [3 ];
72
- private final long quantizeCentroidsLength = (long ) numCentroids * (fieldInfo .getVectorDimension () + 3 * Float .BYTES
73
- + Short .BYTES );
74
-
68
+ NeighborQueue queue = new NeighborQueue (fieldEntry .numCentroids (), true );
69
+ centroids .seek (0L );
70
+ final float [] centroidCorrectiveValues = new float [3 ];
71
+ for (int i = 0 ; i < numCentroids ; i ++) {
72
+ final float qcDist = scorer .int4DotProduct (quantized );
73
+ centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
74
+ final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
75
+ float score = int4QuantizedScore (
76
+ qcDist ,
77
+ queryParams ,
78
+ fieldInfo .getVectorDimension (),
79
+ centroidCorrectiveValues ,
80
+ quantizedCentroidComponentSum ,
81
+ globalCentroidDp ,
82
+ fieldInfo .getVectorSimilarityFunction ()
83
+ );
84
+ queue .add (i , score );
85
+ }
86
+ final long offset = centroids .getFilePointer ();
87
+ return new CentroidIterator () {
75
88
@ Override
76
- public int size () {
77
- return numCentroids ;
89
+ public boolean hasNext () {
90
+ return queue . size () > 0 ;
78
91
}
79
92
80
93
@ Override
81
- public long postingListOffset (int centroidOrdinal ) throws IOException {
82
- if (centroidOrdinal != currentCentroid ) {
83
- centroids .seek (quantizeCentroidsLength + (long ) Long .BYTES * centroidOrdinal );
84
- postingListOffset = centroids .readLong ();
85
- currentCentroid = centroidOrdinal ;
86
- }
87
- return postingListOffset ;
88
- }
89
-
90
- public void bulkScore (NeighborQueue queue ) throws IOException {
91
- // TODO: bulk score centroids like we do with posting lists
92
- centroids .seek (0L );
93
- for (int i = 0 ; i < numCentroids ; i ++) {
94
- queue .add (i , score ());
95
- }
96
- }
97
-
98
- private float score () throws IOException {
99
- final float qcDist = scorer .int4DotProduct (quantized );
100
- centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
101
- final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
102
- return int4QuantizedScore (
103
- qcDist ,
104
- queryParams ,
105
- fieldInfo .getVectorDimension (),
106
- centroidCorrectiveValues ,
107
- quantizedCentroidComponentSum ,
108
- globalCentroidDp ,
109
- fieldInfo .getVectorSimilarityFunction ()
110
- );
111
- }
112
-
113
- // TODO can we do this in off-heap blocks?
114
- private float int4QuantizedScore (
115
- float qcDist ,
116
- OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
117
- int dims ,
118
- float [] targetCorrections ,
119
- int targetComponentSum ,
120
- float centroidDp ,
121
- VectorSimilarityFunction similarityFunction
122
- ) {
123
- float ax = targetCorrections [0 ];
124
- // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
125
- float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
126
- float ay = queryCorrections .lowerInterval ();
127
- float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
128
- float y1 = queryCorrections .quantizedComponentSum ();
129
- float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
130
- if (similarityFunction == EUCLIDEAN ) {
131
- score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
132
- return Math .max (1 / (1f + score ), 0 );
133
- } else {
134
- // For cosine and max inner product, we need to apply the additional correction, which is
135
- // assumed to be the non-centered dot-product between the vector and the centroid
136
- score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
137
- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
138
- return VectorUtil .scaleMaxInnerProductScore (score );
139
- }
140
- return Math .max ((1f + score ) / 2f , 0 );
141
- }
94
+ public long nextPostingListOffset () throws IOException {
95
+ int centroidOrdinal = queue .pop ();
96
+ centroids .seek (offset + (long ) Long .BYTES * centroidOrdinal );
97
+ return centroids .readLong ();
142
98
}
143
99
};
144
100
}
145
101
146
- @ Override
147
- NeighborQueue scorePostingLists (FieldInfo fieldInfo , KnnCollector knnCollector , CentroidQueryScorer centroidQueryScorer , int nProbe )
148
- throws IOException {
149
- NeighborQueue neighborQueue = new NeighborQueue (centroidQueryScorer .size (), true );
150
- centroidQueryScorer .bulkScore (neighborQueue );
151
- return neighborQueue ;
102
+ // TODO can we do this in off-heap blocks?
103
+ private float int4QuantizedScore (
104
+ float qcDist ,
105
+ OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
106
+ int dims ,
107
+ float [] targetCorrections ,
108
+ int targetComponentSum ,
109
+ float centroidDp ,
110
+ VectorSimilarityFunction similarityFunction
111
+ ) {
112
+ float ax = targetCorrections [0 ];
113
+ float lx = (targetCorrections [1 ] - ax ) * FOUR_BIT_SCALE ;
114
+ float ay = queryCorrections .lowerInterval ();
115
+ float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
116
+ float y1 = queryCorrections .quantizedComponentSum ();
117
+ float score = ax * ay * dims + ay * lx * (float ) targetComponentSum + ax * ly * y1 + lx * ly * qcDist ;
118
+ if (similarityFunction == EUCLIDEAN ) {
119
+ score = queryCorrections .additionalCorrection () + targetCorrections [2 ] - 2 * score ;
120
+ return Math .max (1 / (1f + score ), 0 );
121
+ } else {
122
+ // For cosine and max inner product, we need to apply the additional correction, which is
123
+ // assumed to be the non-centered dot-product between the vector and the centroid
124
+ score += queryCorrections .additionalCorrection () + targetCorrections [2 ] - centroidDp ;
125
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
126
+ return VectorUtil .scaleMaxInnerProductScore (score );
127
+ }
128
+ return Math .max ((1f + score ) / 2f , 0 );
129
+ }
152
130
}
153
131
154
132
@ Override
0 commit comments