@@ -54,9 +54,10 @@ type RaBitQuantizer struct {
54
54
// to the statically-allocated arrays in this struct.
55
55
type raBitQuantizedVector struct {
56
56
RaBitQuantizedVectorSet
57
- codeCountStorage [1 ]uint32
58
- centroidDistanceStorage [1 ]float32
59
- dotProductStorage [1 ]float32
57
+ codeCountStorage [1 ]uint32
58
+ centroidDistanceStorage [1 ]float32
59
+ quantizedDotProductStorage [1 ]float32
60
+ centroidDotProductStorage [1 ]float32
60
61
}
61
62
62
63
var _ Quantizer = (* RaBitQuantizer )(nil )
@@ -111,7 +112,7 @@ func (q *RaBitQuantizer) Quantize(w *workspace.T, vectors vector.Set) QuantizedV
111
112
centroid = vectors .Centroid (make (vector.T , vectors .Dims ))
112
113
}
113
114
114
- quantizedSet := q .NewQuantizedVectorSet (vectors .Count , centroid )
115
+ quantizedSet := q .NewSet (vectors .Count , centroid )
115
116
q .quantizeHelper (w , quantizedSet .(* RaBitQuantizedVectorSet ), vectors )
116
117
return quantizedSet
117
118
}
@@ -123,32 +124,43 @@ func (q *RaBitQuantizer) QuantizeInSet(
123
124
q .quantizeHelper (w , quantizedSet .(* RaBitQuantizedVectorSet ), vectors )
124
125
}
125
126
126
- // NewQuantizedVectorSet implements the Quantizer interface
127
- func (q * RaBitQuantizer ) NewQuantizedVectorSet (capacity int , centroid vector.T ) QuantizedVectorSet {
128
- codeWidth := RaBitQCodeSetWidth ( q . GetDims ())
129
- dataBuffer := make ([] uint64 , 0 , capacity * codeWidth )
127
+ // NewSet implements the Quantizer interface
128
+ func (q * RaBitQuantizer ) NewSet (capacity int , centroid vector.T ) QuantizedVectorSet {
129
+ var vs * RaBitQuantizedVectorSet
130
+
130
131
if capacity <= 1 {
131
132
// Special case capacity of zero or one by using in-line storage.
132
- var quantized raBitQuantizedVector
133
- quantized .Centroid = centroid
134
- quantized .Codes = MakeRaBitQCodeSetFromRawData (dataBuffer , codeWidth )
133
+ quantized := & raBitQuantizedVector {}
135
134
quantized .CodeCounts = quantized .codeCountStorage [:0 ]
136
135
quantized .CentroidDistances = quantized .centroidDistanceStorage [:0 ]
137
- quantized .QuantizedDotProducts = quantized .dotProductStorage [:0 ]
138
- return & quantized .RaBitQuantizedVectorSet
139
- }
136
+ quantized .QuantizedDotProducts = quantized .quantizedDotProductStorage [:0 ]
140
137
141
- vs := & RaBitQuantizedVectorSet {
142
- Centroid : centroid ,
143
- Codes : MakeRaBitQCodeSetFromRawData (dataBuffer , codeWidth ),
144
- CodeCounts : make ([]uint32 , 0 , capacity ),
145
- CentroidDistances : make ([]float32 , 0 , capacity ),
146
- QuantizedDotProducts : make ([]float32 , 0 , capacity ),
138
+ // L2Squared doesn't use this.
139
+ if q .distanceMetric != vecpb .L2SquaredDistance {
140
+ quantized .CentroidDotProducts = quantized .centroidDotProductStorage [:0 ]
141
+ }
142
+ vs = & quantized .RaBitQuantizedVectorSet
143
+ } else {
144
+ vs = & RaBitQuantizedVectorSet {
145
+ CodeCounts : make ([]uint32 , 0 , capacity ),
146
+ CentroidDistances : make ([]float32 , 0 , capacity ),
147
+ QuantizedDotProducts : make ([]float32 , 0 , capacity ),
148
+ }
149
+ // L2Squared doesn't use these, so don't make extra allocation or calculation.
150
+ if q .distanceMetric != vecpb .L2SquaredDistance {
151
+ vs .CentroidDotProducts = make ([]float32 , 0 , capacity )
152
+ }
147
153
}
148
- // L2Squared doesn't use this, so don't make extra allocation.
154
+
155
+ vs .Metric = q .distanceMetric
156
+ vs .Centroid = centroid
157
+ codeWidth := RaBitQCodeSetWidth (q .GetDims ())
158
+ dataBuffer := make ([]uint64 , 0 , capacity * codeWidth )
159
+ vs .Codes = MakeRaBitQCodeSetFromRawData (dataBuffer , codeWidth )
149
160
if q .distanceMetric != vecpb .L2SquaredDistance {
150
- vs .CentroidDotProducts = make ([] float32 , 0 , capacity )
161
+ vs .CentroidNorm = num32 . Norm ( centroid )
151
162
}
163
+
152
164
return vs
153
165
}
154
166
@@ -180,28 +192,7 @@ func (q *RaBitQuantizer) EstimateDistances(
180
192
181
193
if queryCentroidDistance == 0 {
182
194
// The query vector is the centroid.
183
- switch q .distanceMetric {
184
- case vecpb .L2SquaredDistance :
185
- // The distance from the query to the data vectors are just the centroid
186
- // distances that have already been calculated, but just need to be
187
- // squared.
188
- num32 .MulTo (distances , raBitSet .CentroidDistances , raBitSet .CentroidDistances )
189
-
190
- case vecpb .InnerProductDistance :
191
- // The dot products between the centroid and the data vectors have
192
- // already been computed, just need to negate them.
193
- num32 .ScaleTo (distances , - 1 , raBitSet .CentroidDotProducts )
194
-
195
- case vecpb .CosineDistance :
196
- // All vectors have been normalized, so cosine distance = 1 - dot product.
197
- num32 .ScaleTo (distances , - 1 , raBitSet .CentroidDotProducts )
198
- num32 .AddConst (1 , distances )
199
-
200
- default :
201
- panic (errors .AssertionFailedf (
202
- "RaBitQuantizer does not support distance metric %s" , q .distanceMetric ))
203
- }
204
-
195
+ q .GetCentroidDistances (quantizedSet , distances , false /* spherical */ )
205
196
num32 .Zero (errorBounds )
206
197
return
207
198
}
@@ -210,7 +201,7 @@ func (q *RaBitQuantizer) EstimateDistances(
210
201
var squaredCentroidNorm , queryCentroidDotProduct float32
211
202
if q .distanceMetric != vecpb .L2SquaredDistance {
212
203
queryCentroidDotProduct = num32 .Dot (queryVector , raBitSet .Centroid )
213
- squaredCentroidNorm = num32 . SquaredNorm ( raBitSet .Centroid )
204
+ squaredCentroidNorm = raBitSet . CentroidNorm * raBitSet .CentroidNorm
214
205
}
215
206
216
207
tempQueryUnitVector := tempQueryDiff
@@ -371,6 +362,48 @@ func (q *RaBitQuantizer) EstimateDistances(
371
362
}
372
363
}
373
364
365
+ // GetCentroidDistances implements the Quantizer interface.
366
+ func (q * RaBitQuantizer ) GetCentroidDistances (
367
+ quantizedSet QuantizedVectorSet , distances []float32 , spherical bool ,
368
+ ) {
369
+ raBitSet := quantizedSet .(* RaBitQuantizedVectorSet )
370
+
371
+ switch q .distanceMetric {
372
+ case vecpb .L2SquaredDistance :
373
+ // The distance from the query to the data vectors are just the centroid
374
+ // distances that have already been calculated, but just need to be
375
+ // squared.
376
+ num32 .MulTo (distances , raBitSet .CentroidDistances , raBitSet .CentroidDistances )
377
+
378
+ case vecpb .InnerProductDistance :
379
+ // Need to negate precomputed centroid dot products to compute inner
380
+ // product distance.
381
+ multiplier := float32 (- 1 )
382
+ if spherical && raBitSet .CentroidNorm != 0 {
383
+ // Convert the mean centroid dot product into a spherical centroid
384
+ // dot product by dividing by the centroid's norm.
385
+ multiplier /= raBitSet .CentroidNorm
386
+ }
387
+ num32 .ScaleTo (distances , multiplier , raBitSet .CentroidDotProducts )
388
+
389
+ case vecpb .CosineDistance :
390
+ // Cosine distance = 1 - dot product when vectors are normalized. The
391
+ // precomputed centroid dot products were computed with normalized data
392
+ // vectors, but the centroid was not normalized. Do that now by dividing
393
+ // the dot products by the centroid's norm. Also negate the result.
394
+ multiplier := float32 (- 1 )
395
+ if raBitSet .CentroidNorm != 0 {
396
+ multiplier /= raBitSet .CentroidNorm
397
+ }
398
+ num32 .ScaleTo (distances , multiplier , raBitSet .CentroidDotProducts )
399
+ num32 .AddConst (1 , distances )
400
+
401
+ default :
402
+ panic (errors .AssertionFailedf (
403
+ "RaBitQuantizer does not support distance metric %s" , q .distanceMetric ))
404
+ }
405
+ }
406
+
374
407
// quantizeHelper quantizes the given set of vectors and adds the quantization
375
408
// information to the provided quantized vector set.
376
409
func (q * RaBitQuantizer ) quantizeHelper (
@@ -383,7 +416,7 @@ func (q *RaBitQuantizer) quantizeHelper(
383
416
// Extend any existing slices in the vector set.
384
417
count := vectors .Count
385
418
oldCount := qs .GetCount ()
386
- qs .AddUndefined (count , q . distanceMetric )
419
+ qs .AddUndefined (count )
387
420
388
421
// L2Squared doesn't use this, so don't store it.
389
422
if q .distanceMetric != vecpb .L2SquaredDistance {
0 commit comments