Skip to content

Commit abf1a09

Browse files
craig[bot]andy-kimball
andcommitted
Merge #148427
148427: cspann: add GetCentroidDistances method to Quantizer interface r=drewkimball a=andy-kimball #### quantize: track distance metric in RaBitQuantizedSet Include the distance metric in RaBitQuantizedSet, since it's needed for methods like Clear and AddUndefined. Also, clean up the encoding methods for RaBitQVector a bit. #### quantize: rename NewQuantizedVectorSet to NewSet Rename the NewQuantizedVectorSet method on the Quantizer interface to NewSet. "QuantizedVector" is already implied by context, and the shorter name is consistent with other methods like "QuantizeInSet". #### quantize: add GetCentroidDistances method to Quantizer interface GetCentroidDistances returns the exact distance of each vector in a quantized set to the centroid of that set, according to the quantizer's distance metric (e.g. L2Squared or Cosine). It also supports distance to the spherical centroid for Cosine and InnerProduct metrics. This method will be used in a later PR by the partition split code. Co-authored-by: Andrew Kimball <[email protected]>
2 parents de926ba + 6d4850f commit abf1a09

File tree

16 files changed

+312
-123
lines changed

16 files changed

+312
-123
lines changed

pkg/sql/vecindex/cspann/partition.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,6 @@ func (p *Partition) Clear() int {
282282
// CreateEmptyPartition returns an empty partition for the given quantizer and
283283
// level.
284284
func CreateEmptyPartition(quantizer quantize.Quantizer, metadata PartitionMetadata) *Partition {
285-
quantizedSet := quantizer.NewQuantizedVectorSet(0, metadata.Centroid)
285+
quantizedSet := quantizer.NewSet(0, metadata.Centroid)
286286
return NewPartition(metadata, quantizer, quantizedSet, []ChildKey(nil), []ValueBytes(nil))
287287
}

pkg/sql/vecindex/cspann/quantize/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ proto_library(
88
strip_import_prefix = "/pkg",
99
visibility = ["//visibility:public"],
1010
deps = [
11+
"//pkg/sql/vecindex/vecpb:vecpb_proto",
1112
"//pkg/util/vector:vector_proto",
1213
"@com_github_gogo_protobuf//gogoproto:gogo_proto",
1314
],
@@ -20,6 +21,7 @@ go_proto_library(
2021
proto = ":quantize_proto",
2122
visibility = ["//visibility:public"],
2223
deps = [
24+
"//pkg/sql/vecindex/vecpb",
2325
"//pkg/util/vector",
2426
"@com_github_gogo_protobuf//gogoproto",
2527
],
@@ -71,5 +73,6 @@ go_test(
7173
"//pkg/util/vector",
7274
"@com_github_cockroachdb_datadriven//:datadriven",
7375
"@com_github_stretchr_testify//require",
76+
"@org_gonum_v1_gonum//floats/scalar",
7477
],
7578
)

pkg/sql/vecindex/cspann/quantize/quantize.proto

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ syntax = "proto3";
77
package cockroach.sql.vecindex.quantize;
88
option go_package = "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/quantize";
99

10+
import "sql/vecindex/vecpb/vec.proto";
1011
import "util/vector/vector.proto";
1112
import "gogoproto/gogo.proto";
1213

@@ -32,30 +33,37 @@ message RaBitQCodeSet {
3233
message RaBitQuantizedVectorSet {
3334
option (gogoproto.equal) = true;
3435

36+
// Metric specifies the metric used to compute similarity between vectors in
37+
// the set.
38+
cockroach.sql.vecindex.vecpb.DistanceMetric metric = 1;
3539
// Centroid is the average of vectors in the set, representing its "center of
3640
// mass". Note that the centroid is computed when a vector set is created and
3741
// is not updated when vectors are added or removed.
38-
// NOTE: By default, this is the mean centroid for the L2Squared distance
39-
// metric, but is the spherical centroid for InnerProduct and Cosine metrics.
40-
// The spherical centroid is simply the mean centroid, but normalized.
42+
// NOTE: This is always the mean centroid, regardless of the distance metric.
43+
// The caller is responsible for converting this to a spherical centroid when
44+
// that's needed.
4145
// NOTE: The centroid should be treated as immutable.
42-
repeated float centroid = 1;
46+
repeated float centroid = 2;
4347
// Codes is a set of RaBitQ quantization codes, with one code per quantized
4448
// vector in the set.
45-
RaBitQCodeSet codes = 2 [(gogoproto.nullable) = false];
49+
RaBitQCodeSet codes = 3 [(gogoproto.nullable) = false];
4650
// CodeCounts records the count of "1" bits in each of the quantization codes.
47-
repeated uint32 code_counts = 3;
51+
repeated uint32 code_counts = 4;
4852
// CentroidDistances is a slice of the exact Euclidean distances (non-squared)
4953
// of the original full-size vectors from the centroid.
50-
repeated float centroid_distances = 4;
54+
repeated float centroid_distances = 5;
5155
// QuantizedDotProducts is a slice of the exact inner products between the
5256
// original full-size vectors and their corresponding quantized vectors.
5357
// NOTE: These values have been inverted (1/inner_product) to avoid expensive
5458
// division during distance estimation.
55-
repeated float quantized_dot_products = 5;
59+
repeated float quantized_dot_products = 6;
5660
// CentroidDotProducts is a slice of the exact inner products between the
5761
// original full-size vectors and the centroid.
58-
repeated float centroid_dot_products = 6;
62+
// NOTE: This is always nil when using the L2Squared distance metric.
63+
repeated float centroid_dot_products = 7;
64+
// CentroidNorm is the L2 norm of the mean centroid.
65+
// NOTE: This is always zero when using the L2Squared distance metric.
66+
float centroid_norm = 8;
5967
}
6068

6169
// UnQuantizedVectorSet trivially implements the QuantizedVectorSet interface,

pkg/sql/vecindex/cspann/quantize/quantizer.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ type Quantizer interface {
4141
// vectors.
4242
QuantizeInSet(w *workspace.T, quantizedSet QuantizedVectorSet, vectors vector.Set)
4343

44-
// NewQuantizedVectorSet returns a new empty vector set preallocated to the
45-
// number of vectors specified.
46-
NewQuantizedVectorSet(capacity int, centroid vector.T) QuantizedVectorSet
44+
// NewSet returns a new empty vector set preallocated to the number of vectors
45+
// specified.
46+
NewSet(capacity int, centroid vector.T) QuantizedVectorSet
4747

4848
// EstimateDistances returns the estimated distances of the query vector from
4949
// each data vector represented in the given quantized vector set, as well as
@@ -61,6 +61,18 @@ type Quantizer interface {
6161
distances []float32,
6262
errorBounds []float32,
6363
)
64+
65+
// GetCentroidDistances returns the exact distance of each vector in
66+
// "quantizedSet" from that set's centroid, according to the quantizer's
67+
// distance metric (e.g. L2Squared or Cosine). By default, it returns
68+
// distances to the mean centroid. However, if "spherical" is true and the
69+
// distance metric is Cosine or InnerProduct, then it returns distances to
70+
// the spherical centroid instead.
71+
//
72+
// The caller is responsible for allocating the "distances" slice with length
73+
// equal to the number of quantized vectors in "quantizedSet". The centroid
74+
// distances will be copied into that slice.
75+
GetCentroidDistances(quantizedSet QuantizedVectorSet, distances []float32, spherical bool)
6476
}
6577

6678
// QuantizedVectorSet is the compressed form of an original set of full-size

pkg/sql/vecindex/cspann/quantize/quantizer_test.go

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ func TestQuantizers(t *testing.T) {
4141
case "estimate-distances":
4242
return state.estimateDistances(t, d)
4343

44+
case "get-centroid-distances":
45+
return state.getCentroidDistances(t, d)
46+
4447
case "calculate-recall":
4548
return state.calculateRecall(t, d)
4649

@@ -100,7 +103,7 @@ func (s *testState) estimateDistances(t *testing.T, d *datadriven.TestData) stri
100103

101104
// Now add the vectors one-by-one and ensure that distances and error
102105
// bounds stay the same.
103-
quantizedSet = quantizer.NewQuantizedVectorSet(vectors.Count, centroid)
106+
quantizedSet = quantizer.NewSet(vectors.Count, centroid)
104107
for i := range vectors.Count {
105108
quantizer.QuantizeInSet(&s.Workspace, quantizedSet, vectors.Slice(i, 1))
106109
}
@@ -172,6 +175,73 @@ func (s *testState) estimateDistances(t *testing.T, d *datadriven.TestData) stri
172175
return buf.String()
173176
}
174177

178+
func (s *testState) getCentroidDistances(t *testing.T, d *datadriven.TestData) string {
179+
var dims int
180+
var err error
181+
var vectors vector.Set
182+
for _, arg := range d.CmdArgs {
183+
switch arg.Key {
184+
case "dims":
185+
require.Len(t, arg.Vals, 1)
186+
dims, err = strconv.Atoi(arg.Vals[0])
187+
require.NoError(t, err)
188+
189+
// Parse the input vectors.
190+
vectors = vector.MakeSet(dims)
191+
for _, line := range strings.Split(d.Input, "\n") {
192+
line = strings.TrimSpace(line)
193+
if len(line) == 0 {
194+
continue
195+
}
196+
197+
vec, err := vector.ParseVector(line)
198+
require.NoError(t, err)
199+
vectors.Add(vec)
200+
}
201+
202+
default:
203+
t.Fatalf("unknown arg: %s", arg.Key)
204+
}
205+
}
206+
207+
var buf bytes.Buffer
208+
doTest := func(metric vecpb.DistanceMetric) {
209+
rabitQ := quantize.NewRaBitQuantizer(dims, 42, metric)
210+
quantizedSet := rabitQ.Quantize(&s.Workspace, vectors).(*quantize.RaBitQuantizedVectorSet)
211+
212+
buf.WriteString(" Centroid = ")
213+
utils.WriteVector(&buf, quantizedSet.Centroid, 4)
214+
buf.WriteByte('\n')
215+
216+
buf.WriteString(" Mean Centroid Distances = ")
217+
distances := make([]float32, quantizedSet.GetCount())
218+
rabitQ.GetCentroidDistances(quantizedSet, distances, false /* spherical */)
219+
utils.WriteVector(&buf, distances, 4)
220+
buf.WriteByte('\n')
221+
222+
buf.WriteString(" Spherical Centroid Distances = ")
223+
rabitQ.GetCentroidDistances(quantizedSet, distances, true /* spherical */)
224+
utils.WriteVector(&buf, distances, 4)
225+
buf.WriteByte('\n')
226+
}
227+
228+
buf.WriteString("L2Squared\n")
229+
doTest(vecpb.L2SquaredDistance)
230+
231+
buf.WriteString("InnerProduct\n")
232+
doTest(vecpb.InnerProductDistance)
233+
234+
// For cosine distance, normalize the input vectors.
235+
for i := range vectors.Count {
236+
num32.Normalize(vectors.At(i))
237+
}
238+
239+
buf.WriteString("Cosine\n")
240+
doTest(vecpb.CosineDistance)
241+
242+
return buf.String()
243+
}
244+
175245
func (s *testState) calculateRecall(t *testing.T, d *datadriven.TestData) string {
176246
var datasetName string
177247
randomize := false

pkg/sql/vecindex/cspann/quantize/rabitq.go

Lines changed: 79 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ type RaBitQuantizer struct {
5454
// to the statically-allocated arrays in this struct.
5555
type raBitQuantizedVector struct {
5656
RaBitQuantizedVectorSet
57-
codeCountStorage [1]uint32
58-
centroidDistanceStorage [1]float32
59-
dotProductStorage [1]float32
57+
codeCountStorage [1]uint32
58+
centroidDistanceStorage [1]float32
59+
quantizedDotProductStorage [1]float32
60+
centroidDotProductStorage [1]float32
6061
}
6162

6263
var _ Quantizer = (*RaBitQuantizer)(nil)
@@ -111,7 +112,7 @@ func (q *RaBitQuantizer) Quantize(w *workspace.T, vectors vector.Set) QuantizedV
111112
centroid = vectors.Centroid(make(vector.T, vectors.Dims))
112113
}
113114

114-
quantizedSet := q.NewQuantizedVectorSet(vectors.Count, centroid)
115+
quantizedSet := q.NewSet(vectors.Count, centroid)
115116
q.quantizeHelper(w, quantizedSet.(*RaBitQuantizedVectorSet), vectors)
116117
return quantizedSet
117118
}
@@ -123,32 +124,43 @@ func (q *RaBitQuantizer) QuantizeInSet(
123124
q.quantizeHelper(w, quantizedSet.(*RaBitQuantizedVectorSet), vectors)
124125
}
125126

126-
// NewQuantizedVectorSet implements the Quantizer interface
127-
func (q *RaBitQuantizer) NewQuantizedVectorSet(capacity int, centroid vector.T) QuantizedVectorSet {
128-
codeWidth := RaBitQCodeSetWidth(q.GetDims())
129-
dataBuffer := make([]uint64, 0, capacity*codeWidth)
127+
// NewSet implements the Quantizer interface.
128+
func (q *RaBitQuantizer) NewSet(capacity int, centroid vector.T) QuantizedVectorSet {
129+
var vs *RaBitQuantizedVectorSet
130+
130131
if capacity <= 1 {
131132
// Special case capacity of zero or one by using in-line storage.
132-
var quantized raBitQuantizedVector
133-
quantized.Centroid = centroid
134-
quantized.Codes = MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth)
133+
quantized := &raBitQuantizedVector{}
135134
quantized.CodeCounts = quantized.codeCountStorage[:0]
136135
quantized.CentroidDistances = quantized.centroidDistanceStorage[:0]
137-
quantized.QuantizedDotProducts = quantized.dotProductStorage[:0]
138-
return &quantized.RaBitQuantizedVectorSet
139-
}
136+
quantized.QuantizedDotProducts = quantized.quantizedDotProductStorage[:0]
140137

141-
vs := &RaBitQuantizedVectorSet{
142-
Centroid: centroid,
143-
Codes: MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth),
144-
CodeCounts: make([]uint32, 0, capacity),
145-
CentroidDistances: make([]float32, 0, capacity),
146-
QuantizedDotProducts: make([]float32, 0, capacity),
138+
// L2Squared doesn't use this.
139+
if q.distanceMetric != vecpb.L2SquaredDistance {
140+
quantized.CentroidDotProducts = quantized.centroidDotProductStorage[:0]
141+
}
142+
vs = &quantized.RaBitQuantizedVectorSet
143+
} else {
144+
vs = &RaBitQuantizedVectorSet{
145+
CodeCounts: make([]uint32, 0, capacity),
146+
CentroidDistances: make([]float32, 0, capacity),
147+
QuantizedDotProducts: make([]float32, 0, capacity),
148+
}
149+
// L2Squared doesn't use these, so don't make extra allocation or calculation.
150+
if q.distanceMetric != vecpb.L2SquaredDistance {
151+
vs.CentroidDotProducts = make([]float32, 0, capacity)
152+
}
147153
}
148-
// L2Squared doesn't use this, so don't make extra allocation.
154+
155+
vs.Metric = q.distanceMetric
156+
vs.Centroid = centroid
157+
codeWidth := RaBitQCodeSetWidth(q.GetDims())
158+
dataBuffer := make([]uint64, 0, capacity*codeWidth)
159+
vs.Codes = MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth)
149160
if q.distanceMetric != vecpb.L2SquaredDistance {
150-
vs.CentroidDotProducts = make([]float32, 0, capacity)
161+
vs.CentroidNorm = num32.Norm(centroid)
151162
}
163+
152164
return vs
153165
}
154166

@@ -180,28 +192,7 @@ func (q *RaBitQuantizer) EstimateDistances(
180192

181193
if queryCentroidDistance == 0 {
182194
// The query vector is the centroid.
183-
switch q.distanceMetric {
184-
case vecpb.L2SquaredDistance:
185-
// The distance from the query to the data vectors are just the centroid
186-
// distances that have already been calculated, but just need to be
187-
// squared.
188-
num32.MulTo(distances, raBitSet.CentroidDistances, raBitSet.CentroidDistances)
189-
190-
case vecpb.InnerProductDistance:
191-
// The dot products between the centroid and the data vectors have
192-
// already been computed, just need to negate them.
193-
num32.ScaleTo(distances, -1, raBitSet.CentroidDotProducts)
194-
195-
case vecpb.CosineDistance:
196-
// All vectors have been normalized, so cosine distance = 1 - dot product.
197-
num32.ScaleTo(distances, -1, raBitSet.CentroidDotProducts)
198-
num32.AddConst(1, distances)
199-
200-
default:
201-
panic(errors.AssertionFailedf(
202-
"RaBitQuantizer does not support distance metric %s", q.distanceMetric))
203-
}
204-
195+
q.GetCentroidDistances(quantizedSet, distances, false /* spherical */)
205196
num32.Zero(errorBounds)
206197
return
207198
}
@@ -210,7 +201,7 @@ func (q *RaBitQuantizer) EstimateDistances(
210201
var squaredCentroidNorm, queryCentroidDotProduct float32
211202
if q.distanceMetric != vecpb.L2SquaredDistance {
212203
queryCentroidDotProduct = num32.Dot(queryVector, raBitSet.Centroid)
213-
squaredCentroidNorm = num32.SquaredNorm(raBitSet.Centroid)
204+
squaredCentroidNorm = raBitSet.CentroidNorm * raBitSet.CentroidNorm
214205
}
215206

216207
tempQueryUnitVector := tempQueryDiff
@@ -371,6 +362,48 @@ func (q *RaBitQuantizer) EstimateDistances(
371362
}
372363
}
373364

365+
// GetCentroidDistances implements the Quantizer interface.
366+
func (q *RaBitQuantizer) GetCentroidDistances(
367+
quantizedSet QuantizedVectorSet, distances []float32, spherical bool,
368+
) {
369+
raBitSet := quantizedSet.(*RaBitQuantizedVectorSet)
370+
371+
switch q.distanceMetric {
372+
case vecpb.L2SquaredDistance:
373+
// The distances from the centroid to the data vectors are just the centroid
374+
// distances that have already been calculated, but just need to be
375+
// squared.
376+
num32.MulTo(distances, raBitSet.CentroidDistances, raBitSet.CentroidDistances)
377+
378+
case vecpb.InnerProductDistance:
379+
// Need to negate precomputed centroid dot products to compute inner
380+
// product distance.
381+
multiplier := float32(-1)
382+
if spherical && raBitSet.CentroidNorm != 0 {
383+
// Convert the mean centroid dot product into a spherical centroid
384+
// dot product by dividing by the centroid's norm.
385+
multiplier /= raBitSet.CentroidNorm
386+
}
387+
num32.ScaleTo(distances, multiplier, raBitSet.CentroidDotProducts)
388+
389+
case vecpb.CosineDistance:
390+
// Cosine distance = 1 - dot product when vectors are normalized. The
391+
// precomputed centroid dot products were computed with normalized data
392+
// vectors, but the centroid was not normalized. Do that now by dividing
393+
// the dot products by the centroid's norm. Also negate the result.
394+
multiplier := float32(-1)
395+
if raBitSet.CentroidNorm != 0 {
396+
multiplier /= raBitSet.CentroidNorm
397+
}
398+
num32.ScaleTo(distances, multiplier, raBitSet.CentroidDotProducts)
399+
num32.AddConst(1, distances)
400+
401+
default:
402+
panic(errors.AssertionFailedf(
403+
"RaBitQuantizer does not support distance metric %s", q.distanceMetric))
404+
}
405+
}
406+
374407
// quantizeHelper quantizes the given set of vectors and adds the quantization
375408
// information to the provided quantized vector set.
376409
func (q *RaBitQuantizer) quantizeHelper(
@@ -383,7 +416,7 @@ func (q *RaBitQuantizer) quantizeHelper(
383416
// Extend any existing slices in the vector set.
384417
count := vectors.Count
385418
oldCount := qs.GetCount()
386-
qs.AddUndefined(count, q.distanceMetric)
419+
qs.AddUndefined(count)
387420

388421
// L2Squared doesn't use this, so don't store it.
389422
if q.distanceMetric != vecpb.L2SquaredDistance {

0 commit comments

Comments
 (0)