Skip to content

Commit adf7f39

Browse files
committed
quantize: add GetCentroidDistances method to Quantizer interface
GetCentroidDistances returns the exact distance of each vector in a quantized set to the centroid of that set, according to the quantizer's distance metric (e.g. L2Squared or Cosine). It also supports distance to the spherical centroid for Cosine and InnerProduct metrics. This method will be used in a later PR by the partition split code. Epic: CRDB-42943 Release note: None
1 parent 67cb06a commit adf7f39

File tree

10 files changed

+242
-35
lines changed

10 files changed

+242
-35
lines changed

pkg/sql/vecindex/cspann/quantize/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,6 @@ go_test(
7171
"//pkg/util/vector",
7272
"@com_github_cockroachdb_datadriven//:datadriven",
7373
"@com_github_stretchr_testify//require",
74+
"@org_gonum_v1_gonum//floats/scalar",
7475
],
7576
)

pkg/sql/vecindex/cspann/quantize/quantize.proto

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ message RaBitQuantizedVectorSet {
3535
// Centroid is the average of vectors in the set, representing its "center of
3636
// mass". Note that the centroid is computed when a vector set is created and
3737
// is not updated when vectors are added or removed.
38-
// NOTE: By default, this is the mean centroid for the L2Squared distance
39-
// metric, but is the spherical centroid for InnerProduct and Cosine metrics.
40-
// The spherical centroid is simply the mean centroid, but normalized.
38+
// NOTE: This is always the mean centroid, regardless of the distance metric.
39+
// The caller is responsible for converting this to a spherical centroid when
40+
// that's needed.
4141
// NOTE: The centroid should be treated as immutable.
4242
repeated float centroid = 1;
4343
// Codes is a set of RaBitQ quantization codes, with one code per quantized
@@ -55,7 +55,11 @@ message RaBitQuantizedVectorSet {
5555
repeated float quantized_dot_products = 5;
5656
// CentroidDotProducts is a slice of the exact inner products between the
5757
// original full-size vectors and the centroid.
58+
// NOTE: This is always nil when using the L2Squared distance metric.
5859
repeated float centroid_dot_products = 6;
60+
// CentroidNorm is the L2 norm of the mean centroid.
61+
// NOTE: This is always zero when using the L2Squared distance metric.
62+
float centroid_norm = 7;
5963
}
6064

6165
// UnQuantizedVectorSet trivially implements the QuantizedVectorSet interface,

pkg/sql/vecindex/cspann/quantize/quantizer.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ type Quantizer interface {
6161
distances []float32,
6262
errorBounds []float32,
6363
)
64+
65+
// GetCentroidDistances returns the exact distance of each vector in
66+
// "quantizedSet" from that set's centroid, according to the quantizer's
67+
// distance metric (e.g. L2Squared or Cosine). By default, it returns
68+
// distances to the mean centroid. However, if "spherical" is true and the
69+
// distance metric is Cosine or InnerProduct, then it returns distances to
70+
// the spherical centroid instead.
71+
//
72+
// The caller is responsible for allocating the "distances" slice with length
73+
// equal to the number of quantized vectors in "quantizedSet". The centroid
74+
// distances will be copied into that slice.
75+
GetCentroidDistances(quantizedSet QuantizedVectorSet, distances []float32, spherical bool)
6476
}
6577

6678
// QuantizedVectorSet is the compressed form of an original set of full-size

pkg/sql/vecindex/cspann/quantize/quantizer_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ func TestQuantizers(t *testing.T) {
4141
case "estimate-distances":
4242
return state.estimateDistances(t, d)
4343

44+
case "get-centroid-distances":
45+
return state.getCentroidDistances(t, d)
46+
4447
case "calculate-recall":
4548
return state.calculateRecall(t, d)
4649

@@ -172,6 +175,73 @@ func (s *testState) estimateDistances(t *testing.T, d *datadriven.TestData) stri
172175
return buf.String()
173176
}
174177

178+
// getCentroidDistances handles the "get-centroid-distances" datadriven
// command. It reads the "dims" argument, parses one vector per non-blank line
// of the test input, and then, for each of the L2Squared, InnerProduct, and
// Cosine metrics, quantizes the vectors with a RaBitQuantizer and prints the
// set's centroid followed by the distances to the mean centroid and to the
// spherical centroid. Any argument other than "dims" fails the test.
func (s *testState) getCentroidDistances(t *testing.T, d *datadriven.TestData) string {
179+
var dims int
180+
var err error
181+
var vectors vector.Set
182+
for _, arg := range d.CmdArgs {
183+
switch arg.Key {
184+
case "dims":
185+
require.Len(t, arg.Vals, 1)
186+
dims, err = strconv.Atoi(arg.Vals[0])
187+
require.NoError(t, err)
188+
189+
// Parse the input vectors, one per line, skipping blank lines. This is
// done here because the vector set can only be created once "dims" is
// known.
190+
vectors = vector.MakeSet(dims)
191+
for _, line := range strings.Split(d.Input, "\n") {
192+
line = strings.TrimSpace(line)
193+
if len(line) == 0 {
194+
continue
195+
}
196+
197+
vec, err := vector.ParseVector(line)
198+
require.NoError(t, err)
199+
vectors.Add(vec)
200+
}
201+
202+
default:
203+
t.Fatalf("unknown arg: %s", arg.Key)
204+
}
205+
}
206+
207+
var buf bytes.Buffer
208+
// doTest quantizes the current contents of "vectors" using the given metric
// and appends the centroid, the mean centroid distances, and the spherical
// centroid distances to buf. The same "distances" slice is reused for both
// GetCentroidDistances calls.
doTest := func(metric vecpb.DistanceMetric) {
209+
rabitQ := quantize.NewRaBitQuantizer(dims, 42, metric)
210+
quantizedSet := rabitQ.Quantize(&s.Workspace, vectors).(*quantize.RaBitQuantizedVectorSet)
211+
212+
buf.WriteString(" Centroid = ")
213+
utils.WriteVector(&buf, quantizedSet.Centroid, 4)
214+
buf.WriteByte('\n')
215+
216+
buf.WriteString(" Mean Centroid Distances = ")
217+
distances := make([]float32, quantizedSet.GetCount())
218+
rabitQ.GetCentroidDistances(quantizedSet, distances, false /* spherical */)
219+
utils.WriteVector(&buf, distances, 4)
220+
buf.WriteByte('\n')
221+
222+
buf.WriteString(" Spherical Centroid Distances = ")
223+
rabitQ.GetCentroidDistances(quantizedSet, distances, true /* spherical */)
224+
utils.WriteVector(&buf, distances, 4)
225+
buf.WriteByte('\n')
226+
}
227+
228+
buf.WriteString("L2Squared\n")
229+
doTest(vecpb.L2SquaredDistance)
230+
231+
buf.WriteString("InnerProduct\n")
232+
doTest(vecpb.InnerProductDistance)
233+
234+
// For cosine distance, normalize the input vectors. This has to come after
// the L2Squared and InnerProduct runs, which use the vectors as parsed.
// NOTE(review): presumably num32.Normalize modifies the vector in place,
// since its result is not used here - confirm.
235+
for i := range vectors.Count {
236+
num32.Normalize(vectors.At(i))
237+
}
238+
239+
buf.WriteString("Cosine\n")
240+
doTest(vecpb.CosineDistance)
241+
242+
return buf.String()
243+
}
244+
175245
func (s *testState) calculateRecall(t *testing.T, d *datadriven.TestData) string {
176246
var datasetName string
177247
randomize := false

pkg/sql/vecindex/cspann/quantize/rabitq.go

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ type RaBitQuantizer struct {
5454
// to the statically-allocated arrays in this struct.
5555
type raBitQuantizedVector struct {
5656
RaBitQuantizedVectorSet
57-
codeCountStorage [1]uint32
58-
centroidDistanceStorage [1]float32
59-
dotProductStorage [1]float32
57+
codeCountStorage [1]uint32
58+
centroidDistanceStorage [1]float32
59+
quantizedDotProductStorage [1]float32
60+
centroidDotProductStorage [1]float32
6061
}
6162

6263
var _ Quantizer = (*RaBitQuantizer)(nil)
@@ -134,7 +135,13 @@ func (q *RaBitQuantizer) NewQuantizedVectorSet(capacity int, centroid vector.T)
134135
quantized.Codes = MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth)
135136
quantized.CodeCounts = quantized.codeCountStorage[:0]
136137
quantized.CentroidDistances = quantized.centroidDistanceStorage[:0]
137-
quantized.QuantizedDotProducts = quantized.dotProductStorage[:0]
138+
quantized.QuantizedDotProducts = quantized.quantizedDotProductStorage[:0]
139+
140+
// L2Squared doesn't use these, so don't make extra calculations.
141+
if q.distanceMetric != vecpb.L2SquaredDistance {
142+
quantized.CentroidDotProducts = quantized.centroidDotProductStorage[:0]
143+
quantized.CentroidNorm = num32.Norm(centroid)
144+
}
138145
return &quantized.RaBitQuantizedVectorSet
139146
}
140147

@@ -145,9 +152,10 @@ func (q *RaBitQuantizer) NewQuantizedVectorSet(capacity int, centroid vector.T)
145152
CentroidDistances: make([]float32, 0, capacity),
146153
QuantizedDotProducts: make([]float32, 0, capacity),
147154
}
148-
// L2Squared doesn't use this, so don't make extra allocation.
155+
// L2Squared doesn't use these, so don't make extra allocation or calculation.
149156
if q.distanceMetric != vecpb.L2SquaredDistance {
150157
vs.CentroidDotProducts = make([]float32, 0, capacity)
158+
vs.CentroidNorm = num32.Norm(centroid)
151159
}
152160
return vs
153161
}
@@ -180,28 +188,7 @@ func (q *RaBitQuantizer) EstimateDistances(
180188

181189
if queryCentroidDistance == 0 {
182190
// The query vector is the centroid.
183-
switch q.distanceMetric {
184-
case vecpb.L2SquaredDistance:
185-
// The distance from the query to the data vectors are just the centroid
186-
// distances that have already been calculated, but just need to be
187-
// squared.
188-
num32.MulTo(distances, raBitSet.CentroidDistances, raBitSet.CentroidDistances)
189-
190-
case vecpb.InnerProductDistance:
191-
// The dot products between the centroid and the data vectors have
192-
// already been computed, just need to negate them.
193-
num32.ScaleTo(distances, -1, raBitSet.CentroidDotProducts)
194-
195-
case vecpb.CosineDistance:
196-
// All vectors have been normalized, so cosine distance = 1 - dot product.
197-
num32.ScaleTo(distances, -1, raBitSet.CentroidDotProducts)
198-
num32.AddConst(1, distances)
199-
200-
default:
201-
panic(errors.AssertionFailedf(
202-
"RaBitQuantizer does not support distance metric %s", q.distanceMetric))
203-
}
204-
191+
q.GetCentroidDistances(quantizedSet, distances, false /* spherical */)
205192
num32.Zero(errorBounds)
206193
return
207194
}
@@ -210,7 +197,7 @@ func (q *RaBitQuantizer) EstimateDistances(
210197
var squaredCentroidNorm, queryCentroidDotProduct float32
211198
if q.distanceMetric != vecpb.L2SquaredDistance {
212199
queryCentroidDotProduct = num32.Dot(queryVector, raBitSet.Centroid)
213-
squaredCentroidNorm = num32.SquaredNorm(raBitSet.Centroid)
200+
squaredCentroidNorm = raBitSet.CentroidNorm * raBitSet.CentroidNorm
214201
}
215202

216203
tempQueryUnitVector := tempQueryDiff
@@ -371,6 +358,48 @@ func (q *RaBitQuantizer) EstimateDistances(
371358
}
372359
}
373360

361+
// GetCentroidDistances implements the Quantizer interface. It fills
// "distances" from values already stored in the quantized set rather than
// from the original vectors:
//   - L2Squared: the stored centroid distances multiplied by themselves;
//     "spherical" is not consulted.
//   - InnerProduct: the negated stored centroid dot products, additionally
//     divided by the centroid norm when "spherical" is true and the norm is
//     non-zero.
//   - Cosine: one minus the stored centroid dot products divided by the
//     centroid norm (when non-zero); "spherical" is not consulted, so both
//     settings give the same result.
//
// "quantizedSet" must be a *RaBitQuantizedVectorSet; the unchecked type
// assertion panics otherwise. An unsupported distance metric also panics.
362+
func (q *RaBitQuantizer) GetCentroidDistances(
363+
quantizedSet QuantizedVectorSet, distances []float32, spherical bool,
364+
) {
365+
raBitSet := quantizedSet.(*RaBitQuantizedVectorSet)
366+
367+
switch q.distanceMetric {
368+
case vecpb.L2SquaredDistance:
// The distances from the centroid to the data vectors are just the centroid
370+
// distances that have already been calculated; they only need to be
371+
// squared.
372+
num32.MulTo(distances, raBitSet.CentroidDistances, raBitSet.CentroidDistances)
373+
374+
case vecpb.InnerProductDistance:
375+
// Need to negate precomputed centroid dot products to compute inner
376+
// product distance.
377+
multiplier := float32(-1)
378+
if spherical && raBitSet.CentroidNorm != 0 {
379+
// Convert the mean centroid dot product into a spherical centroid
380+
// dot product by dividing by the centroid's norm. A zero norm is
// skipped to avoid dividing by zero.
381+
multiplier /= raBitSet.CentroidNorm
382+
}
383+
num32.ScaleTo(distances, multiplier, raBitSet.CentroidDotProducts)
384+
385+
case vecpb.CosineDistance:
386+
// Cosine distance = 1 - dot product when vectors are normalized. The
387+
// precomputed centroid dot products were computed with normalized data
388+
// vectors, but the centroid was not normalized. Do that now by dividing
389+
// the dot products by the centroid's norm. Also negate the result. A zero
// norm is skipped to avoid dividing by zero.
390+
multiplier := float32(-1)
391+
if raBitSet.CentroidNorm != 0 {
392+
multiplier /= raBitSet.CentroidNorm
393+
}
394+
num32.ScaleTo(distances, multiplier, raBitSet.CentroidDotProducts)
395+
num32.AddConst(1, distances)
396+
397+
default:
398+
panic(errors.AssertionFailedf(
399+
"RaBitQuantizer does not support distance metric %s", q.distanceMetric))
400+
}
401+
}
402+
374403
// quantizeHelper quantizes the given set of vectors and adds the quantization
375404
// information to the provided quantized vector set.
376405
func (q *RaBitQuantizer) quantizeHelper(

pkg/sql/vecindex/cspann/quantize/rabitq_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/cockroachdb/cockroach/pkg/util/num32"
1616
"github.com/cockroachdb/cockroach/pkg/util/vector"
1717
"github.com/stretchr/testify/require"
18+
"gonum.org/v1/gonum/floats/scalar"
1819
)
1920

2021
// Basic tests.
@@ -77,20 +78,21 @@ func TestRaBitQuantizerSimple(t *testing.T) {
7778
})
7879

7980
t.Run("empty quantized set with capacity", func(t *testing.T) {
80-
quantizer := NewRaBitQuantizer(65, 42, vecpb.L2SquaredDistance)
81+
quantizer := NewRaBitQuantizer(65, 42, vecpb.InnerProductDistance)
8182
centroid := make([]float32, 65)
8283
for i := range centroid {
8384
centroid[i] = float32(i)
8485
}
85-
quantizedSet := quantizer.NewQuantizedVectorSet(
86-
5, centroid).(*RaBitQuantizedVectorSet)
86+
quantizedSet := quantizer.NewQuantizedVectorSet(5, centroid).(*RaBitQuantizedVectorSet)
8787
require.Equal(t, centroid, quantizedSet.Centroid)
8888
require.Equal(t, 0, quantizedSet.Codes.Count)
8989
require.Equal(t, 2, quantizedSet.Codes.Width)
9090
require.Equal(t, 10, cap(quantizedSet.Codes.Data))
9191
require.Equal(t, 5, cap(quantizedSet.CodeCounts))
9292
require.Equal(t, 5, cap(quantizedSet.CentroidDistances))
9393
require.Equal(t, 5, cap(quantizedSet.QuantizedDotProducts))
94+
require.Equal(t, 5, cap(quantizedSet.CentroidDotProducts))
95+
require.Equal(t, float64(299.07), scalar.Round(float64(quantizedSet.CentroidNorm), 2))
9496
})
9597
}
9698

pkg/sql/vecindex/cspann/quantize/rabitqpb.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1313
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
14+
"github.com/cockroachdb/cockroach/pkg/util/num32"
1415
"github.com/cockroachdb/cockroach/pkg/util/vector"
1516
"github.com/cockroachdb/errors"
1617
)
@@ -140,6 +141,7 @@ func (vs *RaBitQuantizedVectorSet) Clone() QuantizedVectorSet {
140141
CentroidDistances: slices.Clone(vs.CentroidDistances),
141142
QuantizedDotProducts: slices.Clone(vs.QuantizedDotProducts),
142143
CentroidDotProducts: slices.Clone(vs.CentroidDotProducts),
144+
CentroidNorm: vs.CentroidNorm,
143145
}
144146
}
145147

@@ -165,6 +167,9 @@ func (vs *RaBitQuantizedVectorSet) Clear(centroid vector.T) {
165167
vs.CentroidDistances = vs.CentroidDistances[:0]
166168
vs.QuantizedDotProducts = vs.QuantizedDotProducts[:0]
167169
vs.CentroidDotProducts = vs.CentroidDotProducts[:0]
170+
if &vs.Centroid[0] != &centroid[0] {
171+
vs.CentroidNorm = num32.Norm(centroid)
172+
}
168173
}
169174

170175
// AddUndefined adds the given number of quantized vectors to this set. The new

pkg/sql/vecindex/cspann/quantize/testdata/estimate-distances.ddt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,28 @@ Cosine
103103
(-1, 0): exact is 2, estimate is 2 ± 0.7071
104104
(1, 0): exact is 0, estimate is 0 ± 0.7071
105105

106-
# Query vector is equal to the centroid.
106+
# Query vector is equal to the centroid at the origin.
107+
estimate-distances query=[0,0]
108+
[-2,0]
109+
[2,0]
110+
----
111+
L2Squared
112+
Query = (0, 0)
113+
Centroid = (0, 0)
114+
(-2, 0): exact is 4, estimate is 4
115+
(2, 0): exact is 4, estimate is 4
116+
InnerProduct
117+
Query = (0, 0)
118+
Centroid = (0, 0)
119+
(-2, 0): exact is 0, estimate is 0
120+
(2, 0): exact is 0, estimate is 0
121+
Cosine
122+
Query = (0, 0)
123+
Centroid = (0, 0)
124+
(-1, 0): exact is 1, estimate is 1
125+
(1, 0): exact is 1, estimate is 1
126+
127+
# Query vector is equal to the centroid at a non-origin point.
107128
estimate-distances query=[2,2]
108129
[0,2]
109130
[4,2]

0 commit comments

Comments
 (0)