Skip to content

Commit c75a879

Browse files
committed
cspann: enhance SearchSet
Add support in SearchSet for excluding partitions during search. Any vectors in these partitions will not be added to the set. Also add an option that includes the distance of each vector from its centroid in search results. These options will be used by the split operation. Epic: CRDB-42943 Release note: None
1 parent 56287cc commit c75a879

File tree

4 files changed

+81
-17
lines changed

4 files changed

+81
-17
lines changed

pkg/sql/vecindex/cspann/partition.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (p *Partition) Search(
154154
w *workspace.T, partitionKey PartitionKey, queryVector vector.T, searchSet *SearchSet,
155155
) int {
156156
count := p.Count()
157-
tempFloats := w.AllocFloats(count * 2)
157+
tempFloats := w.AllocFloats(count * 3)
158158
defer w.FreeFloats(tempFloats)
159159

160160
// Estimate distances of the data vectors from the query vector.
@@ -163,6 +163,11 @@ func (p *Partition) Search(
163163
p.quantizer.EstimateDistances(
164164
w, p.quantizedSet, queryVector, tempDistances, tempErrorBounds)
165165

166+
tempCentroidDistances := tempFloats[count*2 : count*3]
167+
if searchSet.IncludeCentroidDistances {
168+
p.quantizer.GetCentroidDistances(p.quantizedSet, tempCentroidDistances, true /* spherical */)
169+
}
170+
166171
// Add candidates to the search set, which is responsible for retaining the
167172
// top-k results.
168173
for i := range tempDistances {
@@ -173,6 +178,9 @@ func (p *Partition) Search(
173178
ChildKey: p.childKeys[i],
174179
ValueBytes: p.valueBytes[i],
175180
}
181+
if searchSet.IncludeCentroidDistances {
182+
searchSet.tempResult.CentroidDistance = tempCentroidDistances[i]
183+
}
176184
searchSet.Add(&searchSet.tempResult)
177185
}
178186

pkg/sql/vecindex/cspann/partition_test.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ func TestPartition(t *testing.T) {
4444
valueBytes20b := ValueBytes{11, 12}
4545

4646
var workspace workspace.T
47-
quantizer := quantize.NewUnQuantizer(2, vecpb.L2SquaredDistance)
47+
unquantizer := quantize.NewUnQuantizer(2, vecpb.L2SquaredDistance)
48+
rabitq := quantize.NewRaBitQuantizer(2, 42, vecpb.InnerProductDistance)
4849

4950
// newTestPartition creates a partition with 2 vectors.
50-
newTestPartition := func() *Partition {
51+
newTestPartition := func(quantizer quantize.Quantizer) *Partition {
5152
vectors := vector.MakeSet(2)
5253
vectors.Add(vec10)
5354
vectors.Add(vec20)
@@ -69,7 +70,7 @@ func TestPartition(t *testing.T) {
6970

7071
t.Run("test Init", func(t *testing.T) {
7172
// Validate that Init sets same values.
72-
partition := newTestPartition()
73+
partition := newTestPartition(unquantizer)
7374
var partition2 Partition
7475
partition2.Init(
7576
*partition.Metadata(),
@@ -82,7 +83,7 @@ func TestPartition(t *testing.T) {
8283
})
8384

8485
t.Run("test Clone", func(t *testing.T) {
85-
partition := newTestPartition()
86+
partition := newTestPartition(unquantizer)
8687
cloned := partition.Clone()
8788
require.Equal(t, partition, cloned)
8889

@@ -101,7 +102,7 @@ func TestPartition(t *testing.T) {
101102
})
102103

103104
t.Run("test Add", func(t *testing.T) {
104-
partition := newTestPartition()
105+
partition := newTestPartition(unquantizer)
105106
require.True(t, partition.Add(&workspace, vec40, childKey40, valueBytes40, true /* overwrite */))
106107
require.Equal(t, 4, partition.Count())
107108
require.Equal(t, []ChildKey{childKey10, childKey20, childKey30, childKey40}, partition.ChildKeys())
@@ -111,14 +112,14 @@ func TestPartition(t *testing.T) {
111112
checkPartitionMetadata(t, partition.Metadata(), Level(1), vector.T{4, 3.33})
112113

113114
// Add vector with duplicate key and overwrite=false. Expect no-op.
114-
partition = newTestPartition()
115+
partition = newTestPartition(unquantizer)
115116
require.False(t, partition.Add(&workspace, vec20b, childKey20, valueBytes20b, false /* overwrite */))
116117
require.Equal(t, 3, partition.Count())
117118
require.Equal(t, []ValueBytes{valueBytes10, valueBytes20, valueBytes30}, partition.ValueBytes())
118119

119120
// Add vector with duplicate key and overwrite=true. Expect value to be
120121
// updated.
121-
partition = newTestPartition()
122+
partition = newTestPartition(unquantizer)
122123
require.False(t, partition.Add(&workspace, vec20b, childKey20, valueBytes20b, true /* overwrite */))
123124
require.Equal(t, 3, partition.Count())
124125
require.Equal(t, []ChildKey{childKey10, childKey30, childKey20}, partition.ChildKeys())
@@ -129,7 +130,7 @@ func TestPartition(t *testing.T) {
129130
t.Run("test AddSet", func(t *testing.T) {
130131
// Create empty partition.
131132
metadata := PartitionMetadata{Level: 1, Centroid: vector.T{4, 3}}
132-
partition := CreateEmptyPartition(quantizer, metadata)
133+
partition := CreateEmptyPartition(unquantizer, metadata)
133134

134135
// Add empty set.
135136
vectors := vector.MakeSet(2)
@@ -188,7 +189,7 @@ func TestPartition(t *testing.T) {
188189
t.Run("test Search", func(t *testing.T) {
189190
// Search empty partition.
190191
metadata := PartitionMetadata{Level: LeafLevel, Centroid: vector.T{4, 3}}
191-
partition := CreateEmptyPartition(quantizer, metadata)
192+
partition := CreateEmptyPartition(unquantizer, metadata)
192193
require.Equal(t, Level(1), partition.Level())
193194

194195
searchSet := SearchSet{MaxResults: 1}
@@ -198,7 +199,7 @@ func TestPartition(t *testing.T) {
198199
require.Equal(t, SearchResults(nil), results)
199200

200201
// Search partition with 5 vectors.
201-
partition = newTestPartition()
202+
partition = newTestPartition(unquantizer)
202203
vectors := vector.MakeSet(2)
203204
vectors.Add(vec40)
204205
vectors.Add(vec50)
@@ -219,8 +220,24 @@ func TestPartition(t *testing.T) {
219220
require.Equal(t, SearchResults{result1, result2, result3}, results)
220221
})
221222

223+
t.Run("test Search with IncludeCentroidDistances", func(t *testing.T) {
224+
// Search partition with 3 vectors.
225+
partition := newTestPartition(rabitq)
226+
227+
searchSet := SearchSet{MaxResults: 2, IncludeCentroidDistances: true}
228+
_ = partition.Search(&workspace, RootKey, vector.T{1, 1}, &searchSet)
229+
result1 := SearchResult{
230+
QueryDistance: -11.52, ErrorBound: 8.96, CentroidDistance: -8.45, ParentPartitionKey: 1,
231+
ChildKey: childKey30, ValueBytes: valueBytes30}
232+
result2 := SearchResult{
233+
QueryDistance: -6.1, ErrorBound: 4.48, CentroidDistance: -5.12, ParentPartitionKey: 1,
234+
ChildKey: childKey20, ValueBytes: valueBytes20}
235+
results := roundResults(searchSet.PopResults(), 2)
236+
require.Equal(t, SearchResults{result1, result2}, results)
237+
})
238+
222239
t.Run("test ReplaceWithLast", func(t *testing.T) {
223-
partition := newTestPartition()
240+
partition := newTestPartition(unquantizer)
224241
partition.ReplaceWithLast(0)
225242
require.Equal(t, 2, partition.Count())
226243
require.Equal(t, []ChildKey{childKey30, childKey20}, partition.ChildKeys())
@@ -237,7 +254,7 @@ func TestPartition(t *testing.T) {
237254
})
238255

239256
t.Run("test ReplaceWithLastByKey and Find", func(t *testing.T) {
240-
partition := newTestPartition()
257+
partition := newTestPartition(unquantizer)
241258
require.Equal(t, 0, partition.Find(childKey10))
242259
require.Equal(t, 2, partition.Find(childKey30))
243260
require.True(t, partition.ReplaceWithLastByKey(childKey10))
@@ -252,7 +269,7 @@ func TestPartition(t *testing.T) {
252269
})
253270

254271
t.Run("test Clear", func(t *testing.T) {
255-
partition := newTestPartition()
272+
partition := newTestPartition(unquantizer)
256273
require.Equal(t, 3, partition.Clear())
257274
require.Equal(t, 0, partition.Count())
258275
require.Equal(t, []ChildKey{}, partition.ChildKeys())
@@ -273,6 +290,7 @@ func roundResults(results SearchResults, prec int) SearchResults {
273290
result := &results[i]
274291
result.QueryDistance = float32(scalar.Round(float64(result.QueryDistance), prec))
275292
result.ErrorBound = float32(scalar.Round(float64(result.ErrorBound), prec))
293+
result.CentroidDistance = float32(scalar.Round(float64(result.CentroidDistance), prec))
276294
result.Vector = testutils.RoundFloats(result.Vector, prec)
277295
}
278296
return results

pkg/sql/vecindex/cspann/search_set.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ type SearchResult struct {
2626
// ErrorBound captures the uncertainty of the distance estimate, which is
2727
// highly likely to fall within QueryDistance ± ErrorBound.
2828
ErrorBound float32
29+
// CentroidDistance is the exact distance of the result vector from its
30+
// centroid, according to the distance metric in use (e.g. L2Squared or
31+
// Cosine).
32+
// NOTE: This is only returned if the SearchSet.IncludeCentroidDistances is
33+
// set to true.
34+
CentroidDistance float32
2935
// ParentPartitionKey is the key of the parent of the partition that contains
3036
// the data vector.
3137
ParentPartitionKey PartitionKey
@@ -182,6 +188,15 @@ type SearchSet struct {
182188
// matching primary key.
183189
MatchKey KeyBytes
184190

191+
// ExcludedPartitions specifies which partitions to skip during search.
192+
// Vectors in any of these partitions will not be added to the set.
193+
ExcludedPartitions []PartitionKey
194+
195+
// IncludeCentroidDistances indicates that search results need to have their
196+
// CentroidDistance field set. This records the vector's distance from the
197+
// centroid of its partition.
198+
IncludeCentroidDistances bool
199+
185200
// Stats tracks useful information about the search, such as how many vectors
186201
// and partitions were scanned.
187202
Stats SearchStats
@@ -210,7 +225,7 @@ type SearchSet struct {
210225
func (ss *SearchSet) Init() {
211226
ss.deDuper.Clear()
212227
*ss = SearchSet{
213-
deDuper: ss.deDuper,
228+
deDuper: ss.deDuper,
214229
candidates: ss.candidates[:0],
215230
}
216231
}
@@ -222,7 +237,8 @@ func (ss *SearchSet) Count() int {
222237
return len(ss.candidates)
223238
}
224239

225-
// Clear removes all candidates from the set.
240+
// Clear removes all candidates from the set, but does not otherwise disturb
241+
// other settings.
226242
func (ss *SearchSet) Clear() {
227243
ss.candidates = ss.candidates[:0]
228244
ss.deDuper.Clear()
@@ -254,6 +270,13 @@ func (ss *SearchSet) Add(candidate *SearchResult) {
254270
return
255271
}
256272

273+
// Skip vectors in excluded partitions.
274+
if ss.ExcludedPartitions != nil {
275+
if slices.Contains(ss.ExcludedPartitions, candidate.ParentPartitionKey) {
276+
return
277+
}
278+
}
279+
257280
if ss.candidates == nil {
258281
// Pre-allocate some capacity for candidates.
259282
ss.candidates = make(searchResultHeap, 0, 16)
@@ -277,7 +300,7 @@ func (ss *SearchSet) AddSet(searchSet *SearchSet) {
277300
return
278301
}
279302
ss.candidates = slices.Grow(ss.candidates, len(searchSet.candidates))
280-
if ss.MatchKey != nil {
303+
if ss.MatchKey != nil || ss.ExcludedPartitions != nil {
281304
// Add each candidate individually in order to check the match key.
282305
ss.AddAll(SearchResults(searchSet.candidates))
283306
} else {

pkg/sql/vecindex/cspann/search_set_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,4 +266,19 @@ func TestSearchSet(t *testing.T) {
266266
require.Equal(t, []float64{3, 3, 1, 4, 3, 0.5, 5, 4}, searchSet.FindBestDistances(distances[:8]))
267267
require.Equal(t, []float64{3, 6, 1, 4, 3, 6, 5, 4, 3, 0.5}, searchSet.FindBestDistances(distances[:12]))
268268
})
269+
270+
t.Run("test ExcludedPartitions", func(t *testing.T) {
271+
result9 := SearchResult{
272+
QueryDistance: 9, ErrorBound: 1, ParentPartitionKey: 500, ChildKey: ChildKey{KeyBytes: []byte{90}}}
273+
searchSet := SearchSet{MaxResults: 9, ExcludedPartitions: []PartitionKey{100, 500, 800}}
274+
searchSet.AddAll(SearchResults{
275+
result1, result2, result3, result4, result1, result5, result6, result7, result1, result8, result9})
276+
require.Equal(t, SearchResults{result3, result4, result7, result6, result2}, searchSet.PopResults())
277+
278+
set1 := SearchSet{MaxResults: 3, ExcludedPartitions: []PartitionKey{200}}
279+
set2 := SearchSet{MaxResults: 3}
280+
set2.AddAll(SearchResults{result1, result2, result3})
281+
set1.AddSet(&set2)
282+
require.Equal(t, SearchResults{result3, result1}, set1.PopResults())
283+
})
269284
}

0 commit comments

Comments
 (0)