Skip to content

Commit f9e3388

Browse files
committed
cspann: fix bug in CalculateTruth
Fix bug in CalculateTruth, which was calling vecpb.MeasureDistance with un-normalized vectors for Cosine distance. This returns incorrect results. Epic: CRDB-42943 Release note: None
1 parent 2354426 commit f9e3388

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pkg/sql/vecindex/cspann/testdata/search-embeddings.ddt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,20 +395,20 @@ CV stats:
395395

396396
recall topk=10 beam-size=4 samples=50
397397
----
398-
20.80% recall@10
398+
30.00% recall@10
399399
45 leaf vectors, 74 vectors, 31 full vectors, 7 partitions
400400

401401
recall topk=10 beam-size=8 samples=50
402402
----
403-
28.40% recall@10
403+
43.40% recall@10
404404
88 leaf vectors, 139 vectors, 47 full vectors, 13 partitions
405405

406406
recall topk=10 beam-size=16 samples=50
407407
----
408-
34.60% recall@10
408+
62.80% recall@10
409409
175 leaf vectors, 265 vectors, 66 full vectors, 25 partitions
410410

411411
recall topk=10 beam-size=32 samples=50
412412
----
413-
38.60% recall@10
413+
81.80% recall@10
414414
348 leaf vectors, 447 vectors, 89 full vectors, 42 partitions

pkg/sql/vecindex/cspann/testutils/testutils.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,22 @@ func CalculateTruth[T comparable](
117117
dataVectors vector.Set,
118118
dataKeys []T,
119119
) []T {
120+
var queryNorm float32
121+
if distMetric == vecpb.CosineDistance {
122+
// MeasureDistance assumes input vectors are normalized.
123+
queryNorm = num32.Norm(queryVector)
124+
}
120125
distances := make([]float32, dataVectors.Count)
121126
offsets := make([]int, dataVectors.Count)
122127
for i := range dataVectors.Count {
123-
distances[i] = vecpb.MeasureDistance(distMetric, queryVector, dataVectors.At(i))
128+
data := dataVectors.At(i)
129+
if distMetric == vecpb.CosineDistance {
130+
// MeasureDistance assumes input vectors are normalized, so adjust the
131+
// result.
132+
distances[i] = 1 - num32.Dot(queryVector, data)/(queryNorm*num32.Norm(data))
133+
} else {
134+
distances[i] = vecpb.MeasureDistance(distMetric, queryVector, data)
135+
}
124136
offsets[i] = i
125137
}
126138
sort.SliceStable(offsets, func(i int, j int) bool {

0 commit comments

Comments
 (0)