Skip to content

Commit 56287cc

Browse files
committed
cspann: add BestCentroids vector index test
Add a new "best-centroids" test that prints out the partitions with the closest centroids for a query vector. This is useful when gauging the quality of the index.

Also, update the "recall" test to only sample from vectors that are not part of the index. Achieving high recall is more challenging when searching for such vectors.

Epic: CRDB-42943
Release note: None
1 parent 426de06 commit 56287cc

File tree

4 files changed

+153
-26
lines changed

4 files changed

+153
-26
lines changed

pkg/sql/vecindex/cspann/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ go_test(
8989
"@com_github_cockroachdb_errors//:errors",
9090
"@com_github_guptarohit_asciigraph//:asciigraph",
9191
"@com_github_stretchr_testify//require",
92+
"@org_golang_x_exp//slices",
9293
"@org_gonum_v1_gonum//floats/scalar",
9394
"@org_gonum_v1_gonum//stat",
9495
],

pkg/sql/vecindex/cspann/index_test.go

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/cockroachdb/datadriven"
3535
"github.com/cockroachdb/errors"
3636
"github.com/stretchr/testify/require"
37+
"golang.org/x/exp/slices"
3738
)
3839

3940
func TestIndex(t *testing.T) {
@@ -108,6 +109,9 @@ func TestIndex(t *testing.T) {
108109
case "recall":
109110
result = state.Recall(d)
110111

112+
case "best-centroids":
113+
result = state.BestCentroids(d)
114+
111115
case "validate-tree":
112116
result = state.ValidateTree(d)
113117

@@ -503,7 +507,7 @@ func (s *testState) Recall(d *datadriven.TestData) string {
503507
rng := rand.New(rand.NewSource(int64(seed)))
504508
remaining := make([]int, s.Dataset.Count-len(data))
505509
for i := range remaining {
506-
remaining[i] = i
510+
remaining[i] = len(data) + i
507511
}
508512
rng.Shuffle(len(remaining), func(i, j int) {
509513
remaining[i], remaining[j] = remaining[j], remaining[i]
@@ -522,6 +526,7 @@ func (s *testState) Recall(d *datadriven.TestData) string {
522526
for i := range samples {
523527
// Calculate truth set for the vector.
524528
queryVector := s.Dataset.At(samples[i])
529+
525530
truth := testutils.CalculateTruth(searchSet.MaxResults,
526531
s.Quantizer.GetDistanceMetric(), queryVector, dataVectors, dataKeys)
527532

@@ -554,6 +559,90 @@ func (s *testState) Recall(d *datadriven.TestData) string {
554559
return buf.String()
555560
}
556561

562+
// BestCentroids implements the "best-centroids" datadriven command. It walks
// the index tree from the root, estimates the distance from the query vector
// to every entry stored in each second-level partition, and returns one line
// per child partition for the topk smallest estimates, in ascending order of
// estimated distance. Each line has the form:
//
//	<partition key>: <estimated distance> ± <error bound> (exact=<distance>)
//
// Recognized arguments:
//   - use-dataset: selects the query vector via parseUseDataset; the vector is
//     passed through s.Index.TransformVector before use. If omitted, the query
//     stays the all-zero vector allocated below.
//   - topk: maximum number of lines to print (default 10).
//
// Store errors fail the test through require.NoError.
func (s *testState) BestCentroids(d *datadriven.TestData) string {
563+
randomized := make(vector.T, s.Dataset.Dims)
564+
topk := 10
565+
for _, arg := range d.CmdArgs {
566+
switch arg.Key {
567+
case "use-dataset":
568+
original := s.parseUseDataset(arg)
569+
s.Index.TransformVector(original, randomized)
570+
571+
case "topk":
572+
topk = s.parseInt(arg)
573+
}
574+
}
575+
576+
var w workspace.T
577+
// distances, errorBounds and partitionKeys are parallel slices: the sorting
// and printing code below relies on entry i of each describing the same child
// partition.
var distances, errorBounds []float32
578+
var partitionKeys []cspann.PartitionKey
579+
580+
// findCentroids is declared before it is assigned so that the closure can
// call itself recursively.
var findCentroids func(partitionKey cspann.PartitionKey)
581+
findCentroids = func(partitionKey cspann.PartitionKey) {
582+
partition, err := s.MemStore.TryGetPartition(s.Ctx, s.TreeKey, partitionKey)
583+
require.NoError(s.T, err)
584+
count := partition.Count()
585+
586+
switch partition.Level() {
587+
case cspann.LeafLevel:
588+
// Nothing to do.
589+
590+
case cspann.SecondLevel:
591+
// Extend both slices by count entries so that EstimateDistances can write
// its results directly into the newly added tail.
distances = slices.Grow(distances, count)
592+
distances = distances[:len(distances)+count]
593+
errorBounds = slices.Grow(errorBounds, count)
594+
errorBounds = errorBounds[:len(errorBounds)+count]
595+
596+
partition.Quantizer().EstimateDistances(&w, partition.QuantizedSet(), randomized,
597+
distances[len(distances)-count:],
598+
errorBounds[len(errorBounds)-count:])
599+
600+
// NOTE(review): presumably EstimateDistances emits results in the same order
// as ChildKeys, which keeps partitionKeys aligned with distances - confirm.
for _, key := range partition.ChildKeys() {
601+
partitionKeys = append(partitionKeys, key.PartitionKey)
602+
}
603+
604+
default:
605+
// Descend to next level.
606+
for _, key := range partition.ChildKeys() {
607+
findCentroids(key.PartitionKey)
608+
}
609+
}
610+
}
611+
612+
findCentroids(cspann.RootKey)
613+
614+
// Create offsets for argsort.
615+
offsets := make([]int, len(partitionKeys))
616+
for i := range offsets {
617+
offsets[i] = i
618+
}
619+
620+
// Sort indices by distance (argsort).
621+
slices.SortFunc(offsets, func(a, b int) int {
622+
if distances[a] < distances[b] {
623+
return -1
624+
} else if distances[a] > distances[b] {
625+
return 1
626+
}
627+
return 0
628+
})
629+
630+
// Print top results.
631+
var buf strings.Builder
632+
for i := range min(topk, len(offsets)) {
633+
offset := offsets[i]
634+
635+
partition, err := s.MemStore.TryGetPartition(s.Ctx, s.TreeKey, partitionKeys[offset])
636+
require.NoError(s.T, err)
637+
// NOTE(review): the exact distance is always measured with L2SquaredDistance,
// whereas Recall passes s.Quantizer.GetDistanceMetric() when computing its
// truth set - confirm this is intended for indexes using another metric.
exact := vecpb.MeasureDistance(vecpb.L2SquaredDistance, randomized, partition.Centroid())
638+
639+
fmt.Fprintf(&buf, "%d: %.4f ± %.4f (exact=%.4f)\n",
640+
partitionKeys[offset], distances[offset], errorBounds[offset], exact)
641+
}
642+
643+
return buf.String()
644+
}
645+
557646
func (s *testState) ValidateTree(d *datadriven.TestData) string {
558647
vectorCount := 0
559648
partitionKeys := []cspann.PartitionKey{cspann.RootKey}

pkg/sql/vecindex/cspann/testdata/search-embeddings.ddt

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -91,31 +91,65 @@ recall topk=20 use-dataset=2717 beam-size=8
9191
40.00% recall@20
9292
90 leaf vectors, 143 vectors, 42 full vectors, 13 partitions
9393

94+
# Show the nearest partitions to the "easy" vector, ordered by estimated
95+
# distance to their centroids. Notice that there are several partitions that are
96+
# very near, and yet the "spread" between centroids is fairly large, which makes
97+
# finding results easier.
98+
best-centroids topk=10 use-dataset=8601
99+
----
100+
151: 0.1696 ± 0.0098 (exact=0.1569)
101+
113: 0.2114 ± 0.0091 (exact=0.2164)
102+
150: 0.2365 ± 0.0089 (exact=0.2380)
103+
155: 0.2836 ± 0.0091 (exact=0.2778)
104+
154: 0.2943 ± 0.0108 (exact=0.2954)
105+
68: 0.2953 ± 0.0146 (exact=0.3056)
106+
97: 0.2988 ± 0.0097 (exact=0.3037)
107+
147: 0.2994 ± 0.0156 (exact=0.2853)
108+
159: 0.3001 ± 0.0120 (exact=0.2995)
109+
139: 0.3274 ± 0.0133 (exact=0.3368)
110+
111+
# Show the nearest partitions to the "hard" vector, ordered by estimated
112+
# distance to their centroids. Notice that the partitions are relatively far
113+
# away and are bunched together, with low "spread". This makes finding results
114+
# more difficult.
115+
best-centroids topk=10 use-dataset=2717
116+
----
117+
197: 0.5183 ± 0.0161 (exact=0.5179)
118+
166: 0.5361 ± 0.0223 (exact=0.5644)
119+
170: 0.5403 ± 0.0156 (exact=0.5453)
120+
30: 0.5524 ± 0.0197 (exact=0.5515)
121+
196: 0.5546 ± 0.0206 (exact=0.5621)
122+
187: 0.5646 ± 0.0171 (exact=0.5625)
123+
135: 0.5674 ± 0.0234 (exact=0.6034)
124+
177: 0.5708 ± 0.0254 (exact=0.5674)
125+
61: 0.5755 ± 0.0211 (exact=0.5581)
126+
183: 0.5777 ± 0.0159 (exact=0.5915)
127+
94128
# Test recall at different beam sizes.
95129
recall topk=10 beam-size=2 samples=64
96130
----
97-
34.22% recall@10
131+
29.84% recall@10
98132
21 leaf vectors, 42 vectors, 15 full vectors, 4 partitions
99133

100134
recall topk=10 beam-size=4 samples=64
101135
----
102-
50.31% recall@10
103-
42 leaf vectors, 73 vectors, 19 full vectors, 7 partitions
136+
47.97% recall@10
137+
42 leaf vectors, 74 vectors, 19 full vectors, 7 partitions
104138

105139
recall topk=10 beam-size=8 samples=64
106140
----
107-
73.75% recall@10
108-
84 leaf vectors, 137 vectors, 23 full vectors, 13 partitions
141+
69.06% recall@10
142+
85 leaf vectors, 138 vectors, 24 full vectors, 13 partitions
109143

110144
recall topk=10 beam-size=16 samples=64
111145
----
112-
87.81% recall@10
113-
168 leaf vectors, 262 vectors, 26 full vectors, 25 partitions
146+
87.66% recall@10
147+
168 leaf vectors, 263 vectors, 27 full vectors, 25 partitions
114148

115149
recall topk=10 beam-size=32 samples=64
116150
----
117-
97.50% recall@10
118-
335 leaf vectors, 441 vectors, 29 full vectors, 42 partitions
151+
95.62% recall@10
152+
336 leaf vectors, 442 vectors, 30 full vectors, 42 partitions
119153

120154
# ----------------------------------------------------------------------
121155
# Compare orderings of same dataset with different distance metrics.
@@ -255,23 +289,23 @@ CV stats:
255289

256290
recall topk=10 beam-size=4 samples=50
257291
----
258-
62.40% recall@10
259-
42 leaf vectors, 72 vectors, 18 full vectors, 7 partitions
292+
61.20% recall@10
293+
42 leaf vectors, 72 vectors, 20 full vectors, 7 partitions
260294

261295
recall topk=10 beam-size=8 samples=50
262296
----
263-
83.40% recall@10
297+
79.80% recall@10
264298
83 leaf vectors, 133 vectors, 21 full vectors, 13 partitions
265299

266300
recall topk=10 beam-size=16 samples=50
267301
----
268-
92.60% recall@10
269-
166 leaf vectors, 257 vectors, 24 full vectors, 25 partitions
302+
91.00% recall@10
303+
165 leaf vectors, 256 vectors, 24 full vectors, 25 partitions
270304

271305
recall topk=10 beam-size=32 samples=50
272306
----
273-
98.20% recall@10
274-
329 leaf vectors, 431 vectors, 25 full vectors, 42 partitions
307+
97.20% recall@10
308+
329 leaf vectors, 431 vectors, 26 full vectors, 42 partitions
275309

276310
# ----------------------------------------------------------------------
277311
# Load 950 768-dimension image embeddings and search them using
@@ -288,20 +322,20 @@ CV stats:
288322

289323
recall topk=10 beam-size=4 samples=50
290324
----
291-
55.80% recall@10
292-
44 leaf vectors, 74 vectors, 19 full vectors, 7 partitions
325+
48.60% recall@10
326+
44 leaf vectors, 76 vectors, 20 full vectors, 7 partitions
293327

294328
recall topk=10 beam-size=8 samples=50
295329
----
296-
74.40% recall@10
297-
88 leaf vectors, 143 vectors, 23 full vectors, 13 partitions
330+
69.00% recall@10
331+
88 leaf vectors, 144 vectors, 25 full vectors, 13 partitions
298332

299333
recall topk=10 beam-size=16 samples=50
300334
----
301-
89.00% recall@10
302-
172 leaf vectors, 271 vectors, 27 full vectors, 25 partitions
335+
85.00% recall@10
336+
173 leaf vectors, 272 vectors, 30 full vectors, 25 partitions
303337

304338
recall topk=10 beam-size=32 samples=50
305339
----
306-
97.60% recall@10
307-
344 leaf vectors, 443 vectors, 30 full vectors, 41 partitions
340+
95.20% recall@10
341+
342 leaf vectors, 441 vectors, 33 full vectors, 41 partitions

pkg/sql/vecindex/cspann/utils/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,8 @@ go_test(
2222
name = "utils_test",
2323
srcs = ["slice_test.go"],
2424
embed = [":utils"],
25-
deps = ["@com_github_stretchr_testify//require"],
25+
deps = [
26+
"//pkg/util/buildutil",
27+
"@com_github_stretchr_testify//require",
28+
],
2629
)

0 commit comments

Comments
 (0)