@@ -54,9 +54,10 @@ type RaBitQuantizer struct {
54
54
// to the statically-allocated arrays in this struct.
55
55
type raBitQuantizedVector struct {
56
56
RaBitQuantizedVectorSet
57
- codeCountStorage [1 ]uint32
58
- centroidDistanceStorage [1 ]float32
59
- dotProductStorage [1 ]float32
57
+ codeCountStorage [1 ]uint32
58
+ centroidDistanceStorage [1 ]float32
59
+ quantizedDotProductStorage [1 ]float32
60
+ centroidDotProductStorage [1 ]float32
60
61
}
61
62
62
63
var _ Quantizer = (* RaBitQuantizer )(nil )
@@ -134,7 +135,13 @@ func (q *RaBitQuantizer) NewQuantizedVectorSet(capacity int, centroid vector.T)
134
135
quantized .Codes = MakeRaBitQCodeSetFromRawData (dataBuffer , codeWidth )
135
136
quantized .CodeCounts = quantized .codeCountStorage [:0 ]
136
137
quantized .CentroidDistances = quantized .centroidDistanceStorage [:0 ]
137
- quantized .QuantizedDotProducts = quantized .dotProductStorage [:0 ]
138
+ quantized .QuantizedDotProducts = quantized .quantizedDotProductStorage [:0 ]
139
+
140
+ // L2Squared doesn't use these, so don't make extra calculations.
141
+ if q .distanceMetric != vecpb .L2SquaredDistance {
142
+ quantized .CentroidDotProducts = quantized .centroidDotProductStorage [:0 ]
143
+ quantized .CentroidNorm = num32 .Norm (centroid )
144
+ }
138
145
return & quantized .RaBitQuantizedVectorSet
139
146
}
140
147
@@ -145,9 +152,10 @@ func (q *RaBitQuantizer) NewQuantizedVectorSet(capacity int, centroid vector.T)
145
152
CentroidDistances : make ([]float32 , 0 , capacity ),
146
153
QuantizedDotProducts : make ([]float32 , 0 , capacity ),
147
154
}
148
- // L2Squared doesn't use this , so don't make extra allocation.
155
+ // L2Squared doesn't use these , so don't make extra allocation or calculation .
149
156
if q .distanceMetric != vecpb .L2SquaredDistance {
150
157
vs .CentroidDotProducts = make ([]float32 , 0 , capacity )
158
+ vs .CentroidNorm = num32 .Norm (centroid )
151
159
}
152
160
return vs
153
161
}
@@ -180,28 +188,7 @@ func (q *RaBitQuantizer) EstimateDistances(
180
188
181
189
if queryCentroidDistance == 0 {
182
190
// The query vector is the centroid.
183
- switch q .distanceMetric {
184
- case vecpb .L2SquaredDistance :
185
- // The distance from the query to the data vectors are just the centroid
186
- // distances that have already been calculated, but just need to be
187
- // squared.
188
- num32 .MulTo (distances , raBitSet .CentroidDistances , raBitSet .CentroidDistances )
189
-
190
- case vecpb .InnerProductDistance :
191
- // The dot products between the centroid and the data vectors have
192
- // already been computed, just need to negate them.
193
- num32 .ScaleTo (distances , - 1 , raBitSet .CentroidDotProducts )
194
-
195
- case vecpb .CosineDistance :
196
- // All vectors have been normalized, so cosine distance = 1 - dot product.
197
- num32 .ScaleTo (distances , - 1 , raBitSet .CentroidDotProducts )
198
- num32 .AddConst (1 , distances )
199
-
200
- default :
201
- panic (errors .AssertionFailedf (
202
- "RaBitQuantizer does not support distance metric %s" , q .distanceMetric ))
203
- }
204
-
191
+ q .GetCentroidDistances (quantizedSet , distances , false /* spherical */ )
205
192
num32 .Zero (errorBounds )
206
193
return
207
194
}
@@ -210,7 +197,7 @@ func (q *RaBitQuantizer) EstimateDistances(
210
197
var squaredCentroidNorm , queryCentroidDotProduct float32
211
198
if q .distanceMetric != vecpb .L2SquaredDistance {
212
199
queryCentroidDotProduct = num32 .Dot (queryVector , raBitSet .Centroid )
213
- squaredCentroidNorm = num32 . SquaredNorm ( raBitSet .Centroid )
200
+ squaredCentroidNorm = raBitSet . CentroidNorm * raBitSet .CentroidNorm
214
201
}
215
202
216
203
tempQueryUnitVector := tempQueryDiff
@@ -371,6 +358,48 @@ func (q *RaBitQuantizer) EstimateDistances(
371
358
}
372
359
}
373
360
361
+ // GetCentroidDistances implements the Quantizer interface.
362
+ func (q * RaBitQuantizer ) GetCentroidDistances (
363
+ quantizedSet QuantizedVectorSet , distances []float32 , spherical bool ,
364
+ ) {
365
+ raBitSet := quantizedSet .(* RaBitQuantizedVectorSet )
366
+
367
+ switch q .distanceMetric {
368
+ case vecpb .L2SquaredDistance :
369
+ // The distance from the query to the data vectors are just the centroid
370
+ // distances that have already been calculated, but just need to be
371
+ // squared.
372
+ num32 .MulTo (distances , raBitSet .CentroidDistances , raBitSet .CentroidDistances )
373
+
374
+ case vecpb .InnerProductDistance :
375
+ // Need to negate precomputed centroid dot products to compute inner
376
+ // product distance.
377
+ multiplier := float32 (- 1 )
378
+ if spherical && raBitSet .CentroidNorm != 0 {
379
+ // Convert the mean centroid dot product into a spherical centroid
380
+ // dot product by dividing by the centroid's norm.
381
+ multiplier /= raBitSet .CentroidNorm
382
+ }
383
+ num32 .ScaleTo (distances , multiplier , raBitSet .CentroidDotProducts )
384
+
385
+ case vecpb .CosineDistance :
386
+ // Cosine distance = 1 - dot product when vectors are normalized. The
387
+ // precomputed centroid dot products were computed with normalized data
388
+ // vectors, but the centroid was not normalized. Do that now by dividing
389
+ // the dot products by the centroid's norm. Also negate the result.
390
+ multiplier := float32 (- 1 )
391
+ if raBitSet .CentroidNorm != 0 {
392
+ multiplier /= raBitSet .CentroidNorm
393
+ }
394
+ num32 .ScaleTo (distances , multiplier , raBitSet .CentroidDotProducts )
395
+ num32 .AddConst (1 , distances )
396
+
397
+ default :
398
+ panic (errors .AssertionFailedf (
399
+ "RaBitQuantizer does not support distance metric %s" , q .distanceMetric ))
400
+ }
401
+ }
402
+
374
403
// quantizeHelper quantizes the given set of vectors and adds the quantization
375
404
// information to the provided quantized vector set.
376
405
func (q * RaBitQuantizer ) quantizeHelper (
0 commit comments