Skip to content

Commit f483ecb

Browse files
craig[bot]andy-kimball
andcommitted
Merge #147027
147027: vecindex: fix bugs and clean up r=drewkimball a=andy-kimball #### quantize: fix calculation when vectors are added incrementally The quantization code is mis-calculating the centroid dot product when vectors are added incrementally to a quantized set that uses a non-Euclidean distance metric. Rather than appending the dot products to the end of the CentroidDotProducts slice, it was overwriting the initial elements of the slice. Fix the bug and add more tests that ensure it won't happen again. #### kmeans: fix small issue with convergence calculation Update the convergence calculation to match the logic in scikit-learn, which sums the shifts across all centroids before comparing to the tolerance value. Also improve comments and add an interesting test case. #### quantize: remove the QuantizedVectorSet.GetCentroid method Remove GetCentroid method from the QuantizedVectorSet interface. Having it present on the interface forces every quantizer to store it, even those (like UnQuantizer) that don't need it. It's also redundant in most contexts with Partition.Centroid(), so it's cleaner to have code use that consistently. Co-authored-by: Andrew Kimball <[email protected]>
2 parents 812a56c + 2ce20a2 commit f483ecb

18 files changed

+115
-109
lines changed

pkg/sql/vecindex/cspann/commontest/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func ValidatePartitionsEqual(t *testing.T, l, r *cspann.Partition) {
7272
require.Equal(t, l.Level(), r.Level(), "levels do not match")
7373
require.Equal(t, l.ChildKeys(), r.ChildKeys(), "childKeys do not match")
7474
require.Equal(t, l.ValueBytes(), r.ValueBytes(), "valueBytes do not match")
75-
require.Equal(t, q1.GetCentroid(), q2.GetCentroid(), "centroids do not match")
75+
require.Equal(t, l.Centroid, r.Centroid, "centroids do not match")
7676
require.Equal(t, q1.GetCount(), q2.GetCount(), "counts do not match")
7777
if eq, ok := q1.(equaler); ok {
7878
require.True(t, eq.Equal(q2))

pkg/sql/vecindex/cspann/kmeans.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ func (km *BalancedKmeans) ComputeCentroids(
7878
tolerance := km.calculateTolerance(vectors)
7979

8080
// Pick 2 centroids to start, using the K-means++ algorithm.
81+
// TOOD(andyk): We should consider adding an outer loop here to generate new
82+
// random centroids in the case where more than 2/3 the vectors are assigned
83+
// to one of the partitions. If we're that unbalanced, it might just be
84+
// because we picked bad starting centroids and retrying could correct that.
8185
tempLeftCentroid := km.Workspace.AllocVector(vectors.Dims)
8286
defer km.Workspace.FreeVector(tempLeftCentroid)
8387
newLeftCentroid := tempLeftCentroid
@@ -109,12 +113,15 @@ func (km *BalancedKmeans) ComputeCentroids(
109113
}
110114
calcPartitionCentroid(km.DistanceMetric, vectors, rightOffsets, newRightCentroid)
111115

112-
// Check if algorithm has converged.
113-
// TODO(andyk): Is there anything better than using Euclidean distances to
114-
// check for convergence?
116+
// Check for convergence using the scikit-learn algorithm.
117+
// NOTE: This uses Euclidean distance, even when using spherical centroids
118+
// with Cosine or InnerProduct distances. This approach mirrors the
119+
// spherecluster library. Since spherical centroids are always normalized
120+
// (unit vectors), the squared Euclidean distance is 2x the Cosine or
121+
// InnerProduct distance, so it's a reasonable convergence check.
115122
leftCentroidShift := num32.L2SquaredDistance(leftCentroid, newLeftCentroid)
116123
rightCentroidShift := num32.L2SquaredDistance(rightCentroid, newRightCentroid)
117-
if leftCentroidShift <= tolerance && rightCentroidShift <= tolerance {
124+
if leftCentroidShift+rightCentroidShift <= tolerance {
118125
break
119126
}
120127

pkg/sql/vecindex/cspann/kmeans_test.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ func TestBalancedKMeans(t *testing.T) {
3434
}
3535

3636
workspace := &workspace.T{}
37-
rng := rand.New(rand.NewSource(42))
38-
dataset := testutils.LoadDataset(t, testutils.ImagesDataset)
37+
images := testutils.LoadDataset(t, testutils.ImagesDataset)
38+
laion := testutils.LoadDataset(t, testutils.LaionDataset)
3939

4040
testCases := []struct {
4141
desc string
@@ -46,6 +46,7 @@ func TestBalancedKMeans(t *testing.T) {
4646
leftCentroid vector.T
4747
rightCentroid vector.T
4848
skipPinTest bool
49+
testUnbalanced bool
4950
}{
5051
{
5152
desc: "partition vector set with only 2 elements",
@@ -128,8 +129,8 @@ func TestBalancedKMeans(t *testing.T) {
128129
9, -14, 20,
129130
5, 9, 4,
130131
}, 3),
131-
leftOffsets: []uint64{0, 3},
132-
rightOffsets: []uint64{1, 2, 4},
132+
leftOffsets: []uint64{1, 2},
133+
rightOffsets: []uint64{0, 3, 4},
133134
},
134135
{
135136
desc: "cosine distance",
@@ -147,7 +148,7 @@ func TestBalancedKMeans(t *testing.T) {
147148
{
148149
desc: "high-dimensional unit vectors, Euclidean distance",
149150
distanceMetric: vecdist.L2Squared,
150-
vectors: dataset.Slice(0, 100),
151+
vectors: images.Slice(0, 100),
151152
// It's challenging to test pinLeftCentroid for this case, due to the
152153
// inherent randomness of the K-means++ algorithm. The other test cases
153154
// should be sufficient to test that, however.
@@ -156,22 +157,39 @@ func TestBalancedKMeans(t *testing.T) {
156157
{
157158
desc: "high-dimensional unit vectors, InnerProduct distance",
158159
distanceMetric: vecdist.InnerProduct,
159-
vectors: dataset.Slice(0, 100),
160+
vectors: images.Slice(0, 100),
160161
skipPinTest: true,
161162
},
162163
{
163164
desc: "high-dimensional unit vectors, Cosine distance",
164165
distanceMetric: vecdist.Cosine,
165-
vectors: dataset.Slice(0, 100),
166+
vectors: images.Slice(0, 100),
166167
skipPinTest: true,
167168
},
169+
{
170+
// Note that laion.Slice(0, 100) actually fails the check that vectors
171+
// in the left partition are closer to the left centroid than vectors
172+
// in the right partition. This is because K-means++ happens to pick
173+
// bad centroids that result in > 2/3rd the vectors being closer to
174+
// the right centroid. In that case, the BalancedKMeans class will
175+
// deliberately move vectors to the left partition, even though they
176+
// are closer to the right partition, in order to obey the balancing
177+
// constraint.
178+
desc: "different dataset, InnerProduct distance",
179+
distanceMetric: vecdist.InnerProduct,
180+
vectors: laion.Slice(0, 100),
181+
skipPinTest: true,
182+
testUnbalanced: true,
183+
},
168184
}
169185

170186
for _, tc := range testCases {
171187
t.Run(tc.desc, func(t *testing.T) {
188+
// Re-initialize rng for each iteration so that order of test cases
189+
// doesn't matter.
172190
kmeans := BalancedKmeans{
173191
Workspace: workspace,
174-
Rand: rng,
192+
Rand: rand.New(rand.NewSource(42)),
175193
DistanceMetric: tc.distanceMetric,
176194
}
177195

@@ -217,7 +235,11 @@ func TestBalancedKMeans(t *testing.T) {
217235
// partition than those in the right partition.
218236
leftMean := calcMeanDistance(tc.distanceMetric, tc.vectors, leftCentroid, leftOffsets)
219237
rightMean := calcMeanDistance(tc.distanceMetric, tc.vectors, leftCentroid, rightOffsets)
220-
require.LessOrEqual(t, leftMean, rightMean)
238+
if tc.testUnbalanced {
239+
require.Greater(t, leftMean, rightMean)
240+
} else {
241+
require.LessOrEqual(t, leftMean, rightMean)
242+
}
221243

222244
if !tc.skipPinTest {
223245
// Check that pinning the left centroid returns the same right centroid.

pkg/sql/vecindex/cspann/memstore/memstore.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,17 +695,19 @@ func Load(data []byte) (*Store, error) {
695695

696696
var quantizer quantize.Quantizer
697697
var quantizedSet quantize.QuantizedVectorSet
698+
var centroid vector.T
698699
if partitionProto.RaBitQ != nil {
699700
quantizer = raBitQuantizer
700701
quantizedSet = partitionProto.RaBitQ
702+
centroid = partitionProto.RaBitQ.Centroid
701703
} else {
702704
quantizer = unquantizer
703705
quantizedSet = partitionProto.UnQuantized
704706
}
705707

706708
metadata := cspann.PartitionMetadata{
707709
Level: partitionProto.Metadata.Level,
708-
Centroid: quantizedSet.GetCentroid(),
710+
Centroid: centroid,
709711
StateDetails: cspann.PartitionStateDetails{
710712
State: partitionProto.Metadata.State,
711713
Target1: partitionProto.Metadata.Target1,

pkg/sql/vecindex/cspann/memstore/memstore_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ func TestInMemoryStoreMarshalling(t *testing.T) {
253253
cspann.MakeReadyPartitionMetadata(1, centroid),
254254
unquantizer,
255255
&quantize.UnQuantizedVectorSet{
256-
Centroid: centroid,
257256
Vectors: vector.Set{
258257
Dims: 2,
259258
Count: 3,
@@ -270,7 +269,6 @@ func TestInMemoryStoreMarshalling(t *testing.T) {
270269
cspann.MakeReadyPartitionMetadata(2, centroid),
271270
raBitQuantizer,
272271
&quantize.UnQuantizedVectorSet{
273-
Centroid: centroid,
274272
Vectors: vector.Set{
275273
Dims: 2,
276274
Count: 3,

pkg/sql/vecindex/cspann/partition.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (p *Partition) QuantizedSet() quantize.QuantizedVectorSet {
127127
// Centroid is the full-sized centroid vector for this partition.
128128
// NOTE: The centroid is immutable and therefore this method is thread-safe.
129129
func (p *Partition) Centroid() vector.T {
130-
return p.quantizedSet.GetCentroid()
130+
return p.metadata.Centroid
131131
}
132132

133133
// ChildKeys point to the location of the full-size vectors that are quantized
@@ -271,7 +271,7 @@ func (p *Partition) Find(childKey ChildKey) int {
271271
// vectors that were cleared. The centroid stays the same.
272272
func (p *Partition) Clear() int {
273273
count := len(p.childKeys)
274-
p.quantizedSet.Clear(p.quantizedSet.GetCentroid())
274+
p.quantizedSet.Clear(p.metadata.Centroid)
275275
clear(p.childKeys)
276276
p.childKeys = p.childKeys[:0]
277277
clear(p.valueBytes)

pkg/sql/vecindex/cspann/partition_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func TestPartition(t *testing.T) {
5555
quantizedSet := quantizer.Quantize(&workspace, vectors)
5656
childKeys := []ChildKey{childKey10, childKey20, childKey30}
5757
valueBytes := []ValueBytes{valueBytes10, valueBytes20, valueBytes30}
58-
metadata := MakeReadyPartitionMetadata(1, quantizedSet.GetCentroid())
58+
metadata := MakeReadyPartitionMetadata(1, vectors.Centroid(make(vector.T, vectors.Dims)))
5959
return NewPartition(metadata, quantizer, quantizedSet, childKeys, valueBytes)
6060
}
6161

pkg/sql/vecindex/cspann/quantize/quantize.proto

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,6 @@ message RaBitQuantizedVectorSet {
6464
message UnQuantizedVectorSet {
6565
option (gogoproto.equal) = true;
6666

67-
// Centroid is the average of vectors in the set, representing its "center of
68-
// mass". Note that the centroid is computed when a vector set is created and
69-
// is not updated when vectors are added or removed.
70-
// NOTE: By default, this is the mean centroid for the L2Squared distance
71-
// metric, but is the spherical centroid for InnerProduct and Cosine metrics.
72-
// The spherical centroid is simply the mean centroid, but normalized.
73-
// NOTE: The centroid should be treated as immutable.
74-
repeated float centroid = 1;
75-
reserved 2; // CentroidDistances
7667
// Vectors is the set of original full-size vectors.
77-
cockroach.util.vector.Set vectors = 3 [(gogoproto.nullable) = false];
68+
cockroach.util.vector.Set vectors = 1 [(gogoproto.nullable) = false];
7869
}

pkg/sql/vecindex/cspann/quantize/quantizer.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,6 @@ type QuantizedVectorSet interface {
7070
// GetCount returns the number of quantized vectors in the set.
7171
GetCount() int
7272

73-
// GetCentroid returns the full-size centroid vector for the set. The
74-
// centroid is the average of the vectors across all dimensions.
75-
// NOTE: By default, this is the mean centroid for the L2Squared distance
76-
// metric, but is the spherical centroid for InnerProduct and Cosine metrics.
77-
// The spherical centroid is simply the mean centroid, but normalized.
78-
// NOTE: This centroid is calculated once, when the set is first created. It
79-
// is not updated when quantized vectors are added to or removed from the set.
80-
// Since it is immutable, this method is thread-safe.
81-
GetCentroid() vector.T
82-
8373
// ReplaceWithLast removes the quantized vector at the given offset from the
8474
// set, replacing it with the last quantized vector in the set. The modified
8575
// set has one less element and the last quantized vector's position changes.

pkg/sql/vecindex/cspann/quantize/quantizer_test.go

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,33 +88,63 @@ func (s *testState) estimateDistances(t *testing.T, d *datadriven.TestData) stri
8888
}
8989
}
9090

91+
addVectors := func(
92+
quantizer quantize.Quantizer, centroid vector.T, vectors vector.Set,
93+
) (quantizedSet quantize.QuantizedVectorSet, distances, errorBounds []float32) {
94+
// Quantize all vectors at once.
95+
quantizedSet = quantizer.Quantize(&s.Workspace, vectors)
96+
distances = make([]float32, quantizedSet.GetCount())
97+
errorBounds = make([]float32, quantizedSet.GetCount())
98+
quantizer.EstimateDistances(
99+
&s.Workspace, quantizedSet, queryVector, distances, errorBounds)
100+
101+
// Now add the vectors one-by-one and ensure that distances and error
102+
// bounds stay the same.
103+
quantizedSet = quantizer.NewQuantizedVectorSet(vectors.Count, centroid)
104+
for i := range vectors.Count {
105+
quantizer.QuantizeInSet(&s.Workspace, quantizedSet, vectors.Slice(i, 1))
106+
}
107+
distances2 := make([]float32, quantizedSet.GetCount())
108+
errorBounds2 := make([]float32, quantizedSet.GetCount())
109+
quantizer.EstimateDistances(
110+
&s.Workspace, quantizedSet, queryVector, distances2, errorBounds2)
111+
require.Equal(t, distances2, distances)
112+
require.Equal(t, errorBounds2, errorBounds)
113+
114+
return quantizedSet, distances, errorBounds
115+
}
116+
91117
var buf bytes.Buffer
92118
doTest := func(metric vecdist.Metric, prec int) {
93-
unquantizer := quantize.NewUnQuantizer(len(queryVector), metric)
94-
unQuantizedSet := unquantizer.Quantize(&s.Workspace, vectors)
95-
exact := make([]float32, unQuantizedSet.GetCount())
96-
errorBounds := make([]float32, unQuantizedSet.GetCount())
97-
unquantizer.EstimateDistances(
98-
&s.Workspace, unQuantizedSet, queryVector, exact, errorBounds)
119+
centroid := vectors.Centroid(make(vector.T, vectors.Dims))
120+
if metric == vecdist.InnerProduct || metric == vecdist.Cosine {
121+
// Use spherical centroid for inner product and cosine distances,
122+
// which is the mean centroid, but normalized.
123+
num32.Normalize(centroid)
124+
}
125+
126+
// Test UnQuantizer.
127+
unQuantizer := quantize.NewUnQuantizer(len(queryVector), metric)
128+
quantizedSet, exact, errorBounds := addVectors(unQuantizer, centroid, vectors)
129+
unQuantizedSet := quantizedSet.(*quantize.UnQuantizedVectorSet)
99130
for _, error := range errorBounds {
131+
// ErrorBounds should always be zero for UnQuantizer.
100132
require.Zero(t, error)
101133
}
102134

135+
// Test RaBitQuantizer.
103136
rabitQ := quantize.NewRaBitQuantizer(len(queryVector), 42, metric)
104-
rabitQSet := rabitQ.Quantize(&s.Workspace, vectors)
105-
estimated := make([]float32, rabitQSet.GetCount())
106-
rabitQ.EstimateDistances(
107-
&s.Workspace, rabitQSet, queryVector, estimated, errorBounds)
137+
quantizedSet, estimated, errorBounds := addVectors(rabitQ, centroid, vectors)
138+
rabitQSet := quantizedSet.(*quantize.RaBitQuantizedVectorSet)
108139

109-
// UnQuantizer and RaBitQuantizer should have calculated same centroid.
110-
require.Equal(t, unQuantizedSet.GetCentroid(), rabitQSet.GetCentroid())
140+
require.Equal(t, unQuantizedSet.GetCount(), rabitQSet.GetCount())
111141

112142
buf.WriteString(" Query = ")
113143
utils.WriteVector(&buf, queryVector, 4)
114144
buf.WriteByte('\n')
115145

116146
buf.WriteString(" Centroid = ")
117-
utils.WriteVector(&buf, rabitQSet.GetCentroid(), 4)
147+
utils.WriteVector(&buf, rabitQSet.Centroid, 4)
118148
buf.WriteByte('\n')
119149

120150
for i := range vectors.Count {

0 commit comments

Comments
 (0)