Skip to content

Commit 6d4850f

Browse files
committed
quantize: track distance metric in RaBitQuantizedSet
Include the distance metric in RaBitQuantizedSet, since it's needed for methods like Clear and AddUndefined. Also, clean up the encoding methods for RaBitQVector a bit. Epic: CRDB-42943 Release note: None
1 parent 35f1106 commit 6d4850f

File tree

9 files changed

+61
-79
lines changed

9 files changed

+61
-79
lines changed

pkg/sql/vecindex/cspann/quantize/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ proto_library(
88
strip_import_prefix = "/pkg",
99
visibility = ["//visibility:public"],
1010
deps = [
11+
"//pkg/sql/vecindex/vecpb:vecpb_proto",
1112
"//pkg/util/vector:vector_proto",
1213
"@com_github_gogo_protobuf//gogoproto:gogo_proto",
1314
],
@@ -20,6 +21,7 @@ go_proto_library(
2021
proto = ":quantize_proto",
2122
visibility = ["//visibility:public"],
2223
deps = [
24+
"//pkg/sql/vecindex/vecpb",
2325
"//pkg/util/vector",
2426
"@com_github_gogo_protobuf//gogoproto",
2527
],

pkg/sql/vecindex/cspann/quantize/quantize.proto

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ syntax = "proto3";
77
package cockroach.sql.vecindex.quantize;
88
option go_package = "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/quantize";
99

10+
import "sql/vecindex/vecpb/vec.proto";
1011
import "util/vector/vector.proto";
1112
import "gogoproto/gogo.proto";
1213

@@ -32,34 +33,37 @@ message RaBitQCodeSet {
3233
message RaBitQuantizedVectorSet {
3334
option (gogoproto.equal) = true;
3435

36+
// Metric specifies the metric used to compute similarity between vectors in
37+
// the set.
38+
cockroach.sql.vecindex.vecpb.DistanceMetric metric = 1;
3539
// Centroid is the average of vectors in the set, representing its "center of
3640
// mass". Note that the centroid is computed when a vector set is created and
3741
// is not updated when vectors are added or removed.
3842
// NOTE: This is always the mean centroid, regardless of the distance metric.
3943
// The caller is responsible for converting this to a spherical centroid when
4044
// that's needed.
4145
// NOTE: The centroid should be treated as immutable.
42-
repeated float centroid = 1;
46+
repeated float centroid = 2;
4347
// Codes is a set of RaBitQ quantization codes, with one code per quantized
4448
// vector in the set.
45-
RaBitQCodeSet codes = 2 [(gogoproto.nullable) = false];
49+
RaBitQCodeSet codes = 3 [(gogoproto.nullable) = false];
4650
// CodeCounts records the count of "1" bits in each of the quantization codes.
47-
repeated uint32 code_counts = 3;
51+
repeated uint32 code_counts = 4;
4852
// CentroidDistances is a slice of the exact Euclidean distances (non-squared)
4953
// of the original full-size vectors from the centroid.
50-
repeated float centroid_distances = 4;
54+
repeated float centroid_distances = 5;
5155
// QuantizedDotProducts is a slice of the exact inner products between the
5256
// original full-size vectors and their corresponding quantized vectors.
5357
// NOTE: These values have been inverted (1/inner_product) to avoid expensive
5458
// division during distance estimation.
55-
repeated float quantized_dot_products = 5;
59+
repeated float quantized_dot_products = 6;
5660
// CentroidDotProducts is a slice of the exact inner products between the
5761
// original full-size vectors and the centroid.
5862
// NOTE: This is always nil when using the L2Squared distance metric.
59-
repeated float centroid_dot_products = 6;
63+
repeated float centroid_dot_products = 7;
6064
// CentroidNorm is the L2 norm of the mean centroid.
6165
// NOTE: This is always zero when using the L2Squared distance metric.
62-
float centroid_norm = 7;
66+
float centroid_norm = 8;
6367
}
6468

6569
// UnQuantizedVectorSet trivially implements the QuantizedVectorSet interface,

pkg/sql/vecindex/cspann/quantize/rabitq.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,37 +126,41 @@ func (q *RaBitQuantizer) QuantizeInSet(
126126

127127
// NewSet implements the Quantizer interface
128128
func (q *RaBitQuantizer) NewSet(capacity int, centroid vector.T) QuantizedVectorSet {
129-
codeWidth := RaBitQCodeSetWidth(q.GetDims())
130-
dataBuffer := make([]uint64, 0, capacity*codeWidth)
129+
var vs *RaBitQuantizedVectorSet
130+
131131
if capacity <= 1 {
132132
// Special case capacity of zero or one by using in-line storage.
133-
var quantized raBitQuantizedVector
134-
quantized.Centroid = centroid
135-
quantized.Codes = MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth)
133+
quantized := &raBitQuantizedVector{}
136134
quantized.CodeCounts = quantized.codeCountStorage[:0]
137135
quantized.CentroidDistances = quantized.centroidDistanceStorage[:0]
138136
quantized.QuantizedDotProducts = quantized.quantizedDotProductStorage[:0]
139137

140-
// L2Squared doesn't use these, so don't make extra calculations.
138+
// L2Squared doesn't use this.
141139
if q.distanceMetric != vecpb.L2SquaredDistance {
142140
quantized.CentroidDotProducts = quantized.centroidDotProductStorage[:0]
143-
quantized.CentroidNorm = num32.Norm(centroid)
144141
}
145-
return &quantized.RaBitQuantizedVectorSet
142+
vs = &quantized.RaBitQuantizedVectorSet
143+
} else {
144+
vs = &RaBitQuantizedVectorSet{
145+
CodeCounts: make([]uint32, 0, capacity),
146+
CentroidDistances: make([]float32, 0, capacity),
147+
QuantizedDotProducts: make([]float32, 0, capacity),
148+
}
149+
// L2Squared doesn't use this, so don't make an extra allocation.
150+
if q.distanceMetric != vecpb.L2SquaredDistance {
151+
vs.CentroidDotProducts = make([]float32, 0, capacity)
152+
}
146153
}
147154

148-
vs := &RaBitQuantizedVectorSet{
149-
Centroid: centroid,
150-
Codes: MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth),
151-
CodeCounts: make([]uint32, 0, capacity),
152-
CentroidDistances: make([]float32, 0, capacity),
153-
QuantizedDotProducts: make([]float32, 0, capacity),
154-
}
155-
// L2Squared doesn't use these, so don't make extra allocation or calculation.
155+
vs.Metric = q.distanceMetric
156+
vs.Centroid = centroid
157+
codeWidth := RaBitQCodeSetWidth(q.GetDims())
158+
dataBuffer := make([]uint64, 0, capacity*codeWidth)
159+
vs.Codes = MakeRaBitQCodeSetFromRawData(dataBuffer, codeWidth)
156160
if q.distanceMetric != vecpb.L2SquaredDistance {
157-
vs.CentroidDotProducts = make([]float32, 0, capacity)
158161
vs.CentroidNorm = num32.Norm(centroid)
159162
}
163+
160164
return vs
161165
}
162166

@@ -412,7 +416,7 @@ func (q *RaBitQuantizer) quantizeHelper(
412416
// Extend any existing slices in the vector set.
413417
count := vectors.Count
414418
oldCount := qs.GetCount()
415-
qs.AddUndefined(count, q.distanceMetric)
419+
qs.AddUndefined(count)
416420

417421
// L2Squared doesn't use this, so don't store it.
418422
if q.distanceMetric != vecpb.L2SquaredDistance {

pkg/sql/vecindex/cspann/quantize/rabitqpb.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ func (vs *RaBitQuantizedVectorSet) ReplaceWithLast(offset int) {
135135
// Clone implements the QuantizedVectorSet interface.
136136
func (vs *RaBitQuantizedVectorSet) Clone() QuantizedVectorSet {
137137
return &RaBitQuantizedVectorSet{
138+
Metric: vs.Metric,
138139
Centroid: vs.Centroid, // Centroid is immutable
139140
Codes: vs.Codes.Clone(),
140141
CodeCounts: slices.Clone(vs.CodeCounts),
@@ -167,14 +168,16 @@ func (vs *RaBitQuantizedVectorSet) Clear(centroid vector.T) {
167168
vs.CentroidDistances = vs.CentroidDistances[:0]
168169
vs.QuantizedDotProducts = vs.QuantizedDotProducts[:0]
169170
vs.CentroidDotProducts = vs.CentroidDotProducts[:0]
170-
if &vs.Centroid[0] != &centroid[0] {
171-
vs.CentroidNorm = num32.Norm(centroid)
171+
if vs.Metric != vecpb.L2SquaredDistance {
172+
if &vs.Centroid[0] != &centroid[0] {
173+
vs.CentroidNorm = num32.Norm(centroid)
174+
}
172175
}
173176
}
174177

175178
// AddUndefined adds the given number of quantized vectors to this set. The new
176179
// quantized vector information should be set to defined values before use.
177-
func (vs *RaBitQuantizedVectorSet) AddUndefined(count int, distanceMetric vecpb.DistanceMetric) {
180+
func (vs *RaBitQuantizedVectorSet) AddUndefined(count int) {
178181
newCount := len(vs.CodeCounts) + count
179182
vs.Codes.AddUndefined(count)
180183
vs.CodeCounts = slices.Grow(vs.CodeCounts, count)
@@ -183,7 +186,7 @@ func (vs *RaBitQuantizedVectorSet) AddUndefined(count int, distanceMetric vecpb.
183186
vs.CentroidDistances = vs.CentroidDistances[:newCount]
184187
vs.QuantizedDotProducts = slices.Grow(vs.QuantizedDotProducts, count)
185188
vs.QuantizedDotProducts = vs.QuantizedDotProducts[:newCount]
186-
if distanceMetric != vecpb.L2SquaredDistance {
189+
if vs.Metric != vecpb.L2SquaredDistance {
187190
// L2Squared doesn't need this.
188191
vs.CentroidDotProducts = slices.Grow(vs.CentroidDotProducts, count)
189192
vs.CentroidDotProducts = vs.CentroidDotProducts[:newCount]

pkg/sql/vecindex/cspann/quantize/rabitqpb_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,11 @@ func TestRaBitCodeSet(t *testing.T) {
6262

6363
func TestRaBitQuantizedVectorSet(t *testing.T) {
6464
var quantizedSet RaBitQuantizedVectorSet
65+
quantizedSet.Metric = vecpb.L2SquaredDistance
6566
quantizedSet.Centroid = []float32{1, 2, 3}
6667
quantizedSet.Codes.Width = 3
6768

68-
quantizedSet.AddUndefined(5, vecpb.L2SquaredDistance)
69+
quantizedSet.AddUndefined(5)
6970
copy(quantizedSet.Codes.At(4), []uint64{1, 2, 3})
7071
quantizedSet.CodeCounts[4] = 15
7172
quantizedSet.CentroidDistances[4] = 1.23
@@ -107,7 +108,8 @@ func TestRaBitQuantizedVectorSet(t *testing.T) {
107108
// Test InnerProduct distance metric, which uses the CentroidDotProducts
108109
// field (L2Squared does not use it).
109110
quantizedSet.Clear(quantizedSet.Centroid)
110-
quantizedSet.AddUndefined(2, vecpb.InnerProductDistance)
111+
quantizedSet.Metric = vecpb.InnerProductDistance
112+
quantizedSet.AddUndefined(2)
111113
copy(quantizedSet.Codes.At(1), []uint64{1, 2, 3})
112114
quantizedSet.CodeCounts[1] = 15
113115
quantizedSet.CentroidDistances[1] = 1.23

pkg/sql/vecindex/vecencoding/encoding.go

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -202,22 +202,16 @@ func EncodeMetadataValue(metadata cspann.PartitionMetadata) []byte {
202202
}
203203

204204
// EncodeRaBitQVector encodes a RaBitQ vector into the given byte slice.
205-
func EncodeRaBitQVector(
206-
appendTo []byte,
207-
codeCount uint32,
208-
centroidDistance float32,
209-
quantizedDotProduct float32,
210-
centroidDotProduct float32,
211-
code quantize.RaBitQCode,
212-
metric vecpb.DistanceMetric,
205+
func EncodeRaBitQVectorFromSet(
206+
appendTo []byte, vectorSet *quantize.RaBitQuantizedVectorSet, offset int,
213207
) []byte {
214-
appendTo = encoding.EncodeUint32Ascending(appendTo, codeCount)
215-
appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, centroidDistance)
216-
appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, quantizedDotProduct)
217-
if metric != vecpb.L2SquaredDistance {
218-
appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, centroidDotProduct)
208+
appendTo = encoding.EncodeUint32Ascending(appendTo, vectorSet.CodeCounts[offset])
209+
appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, vectorSet.CentroidDistances[offset])
210+
appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, vectorSet.QuantizedDotProducts[offset])
211+
if vectorSet.Metric != vecpb.L2SquaredDistance {
212+
appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, vectorSet.CentroidDotProducts[offset])
219213
}
220-
for _, c := range code {
214+
for _, c := range vectorSet.Codes.At(offset) {
221215
appendTo = encoding.EncodeUint64Ascending(appendTo, c)
222216
}
223217
return appendTo
@@ -318,7 +312,7 @@ func DecodeMetadataValue(encMetadata []byte) (metadata cspann.PartitionMetadata,
318312
// RaBitQuantizedVectorSet. The vector set must have been initialized with the
319313
// correct number of dimensions. It returns the remainder of the input buffer.
320314
func DecodeRaBitQVectorToSet(
321-
encVector []byte, vectorSet *quantize.RaBitQuantizedVectorSet, metric vecpb.DistanceMetric,
315+
encVector []byte, vectorSet *quantize.RaBitQuantizedVectorSet,
322316
) ([]byte, error) {
323317
encVector, codeCount, err := encoding.DecodeUint32Ascending(encVector)
324318
if err != nil {
@@ -332,7 +326,7 @@ func DecodeRaBitQVectorToSet(
332326
if err != nil {
333327
return nil, err
334328
}
335-
if metric != vecpb.L2SquaredDistance {
329+
if vectorSet.Metric != vecpb.L2SquaredDistance {
336330
var centroidDotProduct float32
337331
encVector, centroidDotProduct, err = encoding.DecodeUntaggedFloat32Value(encVector)
338332
if err != nil {

pkg/sql/vecindex/vecencoding/encoding_test.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,7 @@ func testEncodeDecodeRoundTripImpl(t *testing.T, rnd *rand.Rand, set vector.Set)
9494
buf, err = vecencoding.EncodeUnquantizerVector(buf, set.At(i))
9595
require.NoError(t, err)
9696
case *quantize.RaBitQuantizedVectorSet:
97-
var centroidDotProduct float32
98-
if distMetric != vecpb.L2SquaredDistance {
99-
centroidDotProduct = quantizedSet.CentroidDotProducts[i]
100-
}
101-
buf = vecencoding.EncodeRaBitQVector(buf,
102-
quantizedSet.CodeCounts[i], quantizedSet.CentroidDistances[i],
103-
quantizedSet.QuantizedDotProducts[i], centroidDotProduct,
104-
quantizedSet.Codes.At(i), distMetric,
105-
)
97+
buf = vecencoding.EncodeRaBitQVectorFromSet(buf, quantizedSet, i)
10698
}
10799
}
108100

@@ -131,9 +123,7 @@ func testEncodeDecodeRoundTripImpl(t *testing.T, rnd *rand.Rand, set vector.Set)
131123
decodedSet = quantizer.NewSet(set.Count, decodedMetadata.Centroid)
132124
for range set.Count {
133125
remainder, err = vecencoding.DecodeRaBitQVectorToSet(
134-
remainder,
135-
decodedSet.(*quantize.RaBitQuantizedVectorSet),
136-
distMetric,
126+
remainder, decodedSet.(*quantize.RaBitQuantizedVectorSet),
137127
)
138128
require.NoError(t, err)
139129
}
@@ -188,6 +178,7 @@ func testingAssertPartitionsEqual(t *testing.T, l, r *cspann.Partition) {
188178
case *quantize.RaBitQuantizedVectorSet:
189179
rightSet, ok := q2.(*quantize.RaBitQuantizedVectorSet)
190180
require.True(t, ok, "quantized set types do not match")
181+
require.Equal(t, leftSet.Metric, rightSet.Metric)
191182
require.Equal(t, leftSet.CodeCounts, rightSet.CodeCounts, "code counts do not match")
192183
require.Equal(t, leftSet.Codes, rightSet.Codes, "codes do not match")
193184
require.Equal(t, leftSet.QuantizedDotProducts, rightSet.QuantizedDotProducts,

pkg/sql/vecindex/vecstore/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ go_library(
2828
"//pkg/sql/vecindex/cspann/quantize",
2929
"//pkg/sql/vecindex/cspann/workspace",
3030
"//pkg/sql/vecindex/vecencoding",
31-
"//pkg/sql/vecindex/vecpb",
3231
"//pkg/util/log",
3332
"//pkg/util/unique",
3433
"//pkg/util/vector",

pkg/sql/vecindex/vecstore/codec.go

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/quantize"
1111
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/workspace"
1212
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecencoding"
13-
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1413
"github.com/cockroachdb/cockroach/pkg/util/vector"
1514
"github.com/cockroachdb/errors"
1615
)
@@ -56,9 +55,7 @@ func (sc *storeCodec) DecodeVector(encodedVector []byte) ([]byte, error) {
5655
encodedVector, sc.tmpVectorSet.(*quantize.UnQuantizedVectorSet))
5756
case *quantize.RaBitQuantizer:
5857
return vecencoding.DecodeRaBitQVectorToSet(
59-
encodedVector,
60-
sc.tmpVectorSet.(*quantize.RaBitQuantizedVectorSet),
61-
sc.quantizer.GetDistanceMetric(),
58+
encodedVector, sc.tmpVectorSet.(*quantize.RaBitQuantizedVectorSet),
6259
)
6360
}
6461
return nil, errors.Errorf("unknown quantizer type %T", sc.quantizer)
@@ -75,21 +72,7 @@ func (sc *storeCodec) EncodeVector(w *workspace.T, v vector.T, centroid vector.T
7572
case *quantize.UnQuantizedVectorSet:
7673
return vecencoding.EncodeUnquantizerVector([]byte{}, t.Vectors.At(0))
7774
case *quantize.RaBitQuantizedVectorSet:
78-
metric := sc.quantizer.GetDistanceMetric()
79-
var centroidDotProduct float32
80-
if metric != vecpb.L2SquaredDistance {
81-
// CentroidDotProducts is only defined for non-L2 distance metrics.
82-
centroidDotProduct = t.CentroidDotProducts[0]
83-
}
84-
return vecencoding.EncodeRaBitQVector(
85-
[]byte{},
86-
t.CodeCounts[0],
87-
t.CentroidDistances[0],
88-
t.QuantizedDotProducts[0],
89-
centroidDotProduct,
90-
t.Codes.At(0),
91-
metric,
92-
), nil
75+
return vecencoding.EncodeRaBitQVectorFromSet([]byte{}, t, 0), nil
9376
default:
9477
return nil, errors.Errorf("unknown quantizer type %T", t)
9578
}

0 commit comments

Comments
 (0)