Skip to content

Commit fd3daf9

Browse files
committed
cspann: move vectors from sibling partitions during split
When a partition is split, the left and right target partitions are assigned new centroids. This makes it possible for vectors in other partitions at the same level to now be closer to one of those new centroids than they are to their own centroid. In that case, we need to move those vectors to whichever target partition is now closer. Epic: CRDB-42943 Release note: None
1 parent c75a879 commit fd3daf9

17 files changed

+1003
-824
lines changed

pkg/sql/vecindex/cspann/fixup_processor.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ func (fp *FixupProcessor) nextFixup(ctx context.Context) (next fixup, ok bool) {
486486
}()
487487
}
488488

489-
if discard {
489+
// Always process fixup if it's single-stepping.
490+
if discard && !next.SingleStep {
490491
fp.removeFixup(next)
491492
continue
492493
}

pkg/sql/vecindex/cspann/fixup_split.go

Lines changed: 192 additions & 47 deletions
Large diffs are not rendered by default.

pkg/sql/vecindex/cspann/fixup_worker.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ type fixupWorker struct {
9797
tempChildKey [1]ChildKey
9898
tempValueBytes [1]ValueBytes
9999
tempMetadataToGet []PartitionMetadataToGet
100+
tempIndexCtx Context
101+
tempPartitionKeys [3]PartitionKey
100102
}
101103

102104
// ewFixupWorker returns a new worker for the given processor.
@@ -240,24 +242,33 @@ func (fw *fixupWorker) getFullVectorsForPartition(
240242
vectors = vector.MakeSet(fw.index.quantizer.GetDims())
241243
vectors.AddUndefined(len(fw.tempVectorsWithKeys))
242244
for i := range fw.tempVectorsWithKeys {
243-
if partition.Level() == LeafLevel {
244-
// Leaf vectors from the primary index need to be randomized and
245-
// possibly normalized.
246-
fw.index.TransformVector(fw.tempVectorsWithKeys[i].Vector, vectors.At(i))
247-
} else {
248-
copy(vectors.At(i), fw.tempVectorsWithKeys[i].Vector)
249-
250-
// Convert mean centroids into spherical centroids for the Cosine
251-
// and InnerProduct distance metrics.
252-
switch fw.index.quantizer.GetDistanceMetric() {
253-
case vecpb.CosineDistance, vecpb.InnerProductDistance:
254-
num32.Normalize(vectors.At(i))
255-
}
256-
}
245+
fw.transformFullVector(partition.Level(), fw.tempVectorsWithKeys[i].Vector, vectors.At(i))
257246
}
258247

259248
return nil
260249
})
261250

262251
return vectors, err
263252
}
253+
254+
// transformFullVector ensures that the full vector fetched from a partition at
255+
// the given level has been properly randomized and normalized. It copies the
256+
// randomized, normalized vector into "randomized", which must be allocated by
257+
// the caller with the same length as the input vector.
258+
func (fw *fixupWorker) transformFullVector(level Level, vec, randomized vector.T) {
259+
if level == LeafLevel {
260+
// Leaf vectors from the primary index need to be randomized and possibly
261+
// normalized.
262+
fw.index.TransformVector(vec, randomized)
263+
} else {
264+
// This is an interior level, which means the vector is a partition
265+
// centroid that's already normalized. However, it's a mean centroid, and
266+
// needs to be converted into a spherical centroid for the Cosine and
267+
// InnerProduct distance metrics.
268+
copy(randomized, vec)
269+
switch fw.index.quantizer.GetDistanceMetric() {
270+
case vecpb.CosineDistance, vecpb.InnerProductDistance:
271+
num32.Normalize(randomized)
272+
}
273+
}
274+
}

pkg/sql/vecindex/cspann/index.go

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ func (vi *Index) Search(
474474
searchSet *SearchSet,
475475
options SearchOptions,
476476
) error {
477-
vi.setupContext(idxCtx, treeKey, vec, options, LeafLevel)
477+
vi.setupContext(idxCtx, treeKey, options, LeafLevel, vec, false /* transformed */)
478478
return vi.searchHelper(ctx, idxCtx, searchSet)
479479
}
480480

@@ -595,11 +595,11 @@ func (vi *Index) ForceMerge(
595595
func (vi *Index) setupInsertContext(idxCtx *Context, treeKey TreeKey, vec vector.T) {
596596
// Perform the search using quantized vectors rather than full vectors (i.e.
597597
// skip reranking).
598-
vi.setupContext(idxCtx, treeKey, vec, SearchOptions{
598+
vi.setupContext(idxCtx, treeKey, SearchOptions{
599599
BaseBeamSize: vi.options.BaseBeamSize,
600600
SkipRerank: true,
601601
UpdateStats: !vi.options.DisableAdaptiveSearch,
602-
}, SecondLevel)
602+
}, SecondLevel, vec, false /* transformed */)
603603
idxCtx.forInsert = true
604604
}
605605

@@ -608,27 +608,39 @@ func (vi *Index) setupDeleteContext(idxCtx *Context, treeKey TreeKey, vec vector
608608
// Perform the search using quantized vectors rather than full vectors (i.e.
609609
// skip reranking). Use a larger beam size to make it more likely that we'll
610610
// find the vector to delete.
611-
vi.setupContext(idxCtx, treeKey, vec, SearchOptions{
611+
vi.setupContext(idxCtx, treeKey, SearchOptions{
612612
BaseBeamSize: vi.options.BaseBeamSize * 2,
613613
SkipRerank: true,
614614
UpdateStats: !vi.options.DisableAdaptiveSearch,
615-
}, LeafLevel)
615+
}, LeafLevel, vec, false /* transformed */)
616616
idxCtx.forDelete = true
617617
}
618618

619-
// setupContext sets up the given context as an operation is beginning.
619+
// setupContext sets up the given context as an operation is beginning. If
620+
// "randomized" is false, then the given vector is expected to be an original,
621+
// unrandomized vector that was provided by the user. If "transformed" is true,
622+
// then the vector is expected to already be randomized and normalized.
620623
func (vi *Index) setupContext(
621-
idxCtx *Context, treeKey TreeKey, vec vector.T, options SearchOptions, level Level,
624+
idxCtx *Context,
625+
treeKey TreeKey,
626+
options SearchOptions,
627+
level Level,
628+
vec vector.T,
629+
transformed bool,
622630
) {
623631
idxCtx.treeKey = treeKey
624632
idxCtx.level = level
625-
idxCtx.query.Init(vi.quantizer.GetDistanceMetric(), vec, &vi.rot)
626633
idxCtx.forInsert = false
627634
idxCtx.forDelete = false
628635
idxCtx.options = options
629636
if idxCtx.options.BaseBeamSize == 0 {
630637
idxCtx.options.BaseBeamSize = vi.options.BaseBeamSize
631638
}
639+
if transformed {
640+
idxCtx.query.InitTransformed(vi.quantizer.GetDistanceMetric(), vec, &vi.rot)
641+
} else {
642+
idxCtx.query.InitOriginal(vi.quantizer.GetDistanceMetric(), vec, &vi.rot)
643+
}
632644
}
633645

634646
// updateFunc is called by searchForUpdateHelper when it has a candidate
@@ -920,6 +932,25 @@ func (vi *Index) findExactDistances(
920932
return candidates, nil
921933
}
922934

935+
var err error
936+
candidates, err = vi.getFullVectors(ctx, idxCtx, candidates)
937+
if err != nil {
938+
return nil, errors.Wrapf(err, "getting full vectors to find exact distances")
939+
}
940+
941+
// Compute exact distance between query vector and the data vectors.
942+
idxCtx.query.ComputeExactDistances(idxCtx.level, candidates)
943+
944+
return candidates, nil
945+
}
946+
947+
// getFullVectors fetches the full-size vectors for the given search candidates.
948+
// These can be either leaf vectors fetched from the primary index or interior
949+
// partition centroids fetched from the index. Fixups are enqueued for any
950+
// vectors found to be "dangling".
951+
func (vi *Index) getFullVectors(
952+
ctx context.Context, idxCtx *Context, candidates []SearchResult,
953+
) ([]SearchResult, error) {
923954
// Prepare vector references.
924955
idxCtx.tempVectorsWithKeys = utils.EnsureSliceLen(idxCtx.tempVectorsWithKeys, len(candidates))
925956
for i := range candidates {
@@ -956,9 +987,6 @@ func (vi *Index) findExactDistances(
956987
}
957988
}
958989

959-
// Compute exact distance between query vector and the data vectors.
960-
idxCtx.query.ComputeExactDistances(idxCtx.level, candidates)
961-
962990
return candidates, nil
963991
}
964992

pkg/sql/vecindex/cspann/index_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,8 @@ func (s *testState) ForceSplitOrMerge(d *datadriven.TestData) string {
446446

447447
case "steps":
448448
steps = s.parseInt(arg)
449+
// Always discard any fixups triggered by the split or merge fixup.
450+
s.DiscardFixups = true
449451
}
450452
}
451453

pkg/sql/vecindex/cspann/query_comparer.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,35 @@ type queryComparer struct {
2828
randomized vector.T
2929
}
3030

31-
// Init sets the query vector and prepares the comparer for use.
32-
func (c *queryComparer) Init(
33-
distanceMetric vecpb.DistanceMetric, queryVector vector.T, rot *RandomOrthoTransformer,
31+
// InitOriginal sets the original query vector and prepares the comparer for
32+
// use.
33+
func (c *queryComparer) InitOriginal(
34+
distanceMetric vecpb.DistanceMetric, original vector.T, rot *RandomOrthoTransformer,
3435
) {
3536
c.distanceMetric = distanceMetric
36-
c.original = queryVector
37+
c.original = original
3738

3839
// Randomize the original query vector.
39-
c.randomized = utils.EnsureSliceLen(c.randomized, len(queryVector))
40-
c.randomized = rot.RandomizeVector(queryVector, c.randomized)
40+
c.randomized = utils.EnsureSliceLen(c.randomized, len(original))
41+
c.randomized = rot.RandomizeVector(original, c.randomized)
4142

4243
// If using cosine distance, also normalize the query vector.
4344
if c.distanceMetric == vecpb.CosineDistance {
4445
num32.Normalize(c.randomized)
4546
}
4647
}
4748

49+
// InitRandomized sets the transformed query vector in cases where the original
50+
// query vector is not available, such as when the vector is an interior
51+
// partition centroid. It is expected to already be randomized and normalized.
52+
func (c *queryComparer) InitTransformed(
53+
distanceMetric vecpb.DistanceMetric, randomized vector.T, rot *RandomOrthoTransformer,
54+
) {
55+
c.distanceMetric = distanceMetric
56+
c.original = nil
57+
c.randomized = randomized
58+
}
59+
4860
// Randomized returns the query vector after it has been randomized and
4961
// normalized as needed.
5062
func (c *queryComparer) Randomized() vector.T {

pkg/sql/vecindex/cspann/query_comparer_test.go

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1212
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
1313
"github.com/cockroachdb/cockroach/pkg/util/log"
14+
"github.com/cockroachdb/cockroach/pkg/util/num32"
1415
"github.com/cockroachdb/cockroach/pkg/util/vector"
1516
"github.com/stretchr/testify/require"
1617
)
@@ -158,16 +159,49 @@ func TestQueryComparer(t *testing.T) {
158159
t.Run(tc.name, func(t *testing.T) {
159160
// Setup queryComparer.
160161
var rot RandomOrthoTransformer
161-
rot.Init(vecpb.RotNone, len(tc.queryVector), 42)
162+
rot.Init(vecpb.RotGivens, len(tc.queryVector), 42)
162163

163164
var comparer queryComparer
164-
comparer.Init(tc.metric, tc.queryVector, &rot)
165+
comparer.InitOriginal(tc.metric, tc.queryVector, &rot)
165166

166-
// Make a copy of candidates to avoid modifying test data.
167+
// Make a copy of the candidates.
167168
candidates := make([]SearchResult, len(tc.candidates))
168169
copy(candidates, tc.candidates)
169170

170-
// Test the main method.
171+
// Randomize candidates from an interior level.
172+
if tc.level != LeafLevel {
173+
for i := range tc.candidates {
174+
rot.RandomizeVector(tc.candidates[i].Vector, candidates[i].Vector)
175+
}
176+
}
177+
178+
// Test ComputeExactDistances.
179+
comparer.ComputeExactDistances(tc.level, candidates)
180+
require.Len(t, candidates, len(tc.expected), "number of candidates should be preserved")
181+
182+
for i, expected := range tc.expected {
183+
require.InDelta(t, expected, candidates[i].QueryDistance, 1e-5,
184+
"distance mismatch for candidate %d", i)
185+
186+
// Error bound should always be 0 for exact distances.
187+
require.Equal(t, float32(0), candidates[i].ErrorBound,
188+
"error bound should be 0 for exact distances")
189+
}
190+
191+
// Test InitRandomized for interior levels.
192+
if tc.level == LeafLevel {
193+
return
194+
}
195+
196+
// Transform the query vector.
197+
queryVector := make([]float32, len(tc.queryVector))
198+
rot.RandomizeVector(tc.queryVector, queryVector)
199+
if tc.metric == vecpb.CosineDistance {
200+
num32.Normalize(queryVector)
201+
}
202+
comparer.InitTransformed(tc.metric, queryVector, &rot)
203+
204+
// Test ComputeExactDistances.
171205
comparer.ComputeExactDistances(tc.level, candidates)
172206
require.Len(t, candidates, len(tc.expected), "number of candidates should be preserved")
173207

pkg/sql/vecindex/cspann/searcher.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package cspann
88
import (
99
"context"
1010
"math"
11+
"slices"
1112

1213
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
1314
"github.com/cockroachdb/errors"
@@ -172,6 +173,8 @@ type levelSearcher struct {
172173
parent *levelSearcher
173174
// stats points to search stats that should be updated as the search runs.
174175
stats *SearchStats
176+
// excludedPartitions specifies which partitions to skip during search.
177+
excludedPartitions []PartitionKey
175178
// level is the K-means tree level being searched. This is undefined for the
176179
// root level until SearchRoot is called.
177180
level Level
@@ -244,7 +247,11 @@ func (s *levelSearcher) Init(
244247
}
245248
s.searchSet.MaxResults = maxResults
246249
s.searchSet.MaxExtraResults = maxExtraResults
250+
251+
// Set additional fields that only apply to the last level.
247252
s.searchSet.MatchKey = searchSet.MatchKey
253+
s.searchSet.IncludeCentroidDistances = searchSet.IncludeCentroidDistances
254+
s.excludedPartitions = searchSet.ExcludedPartitions
248255
}
249256
}
250257
}
@@ -298,12 +305,27 @@ func (s *levelSearcher) NextBatch(ctx context.Context) (ok bool, err error) {
298305
// overlap with the previous batch, in terms of ordering and duplicates.
299306
s.searchSet.Clear()
300307

308+
// filterPartitions filters parent results so that we don't even attempt to
309+
// search excluded partitions.
310+
filterPartitions := func(results []SearchResult) []SearchResult {
311+
if s.excludedPartitions == nil {
312+
return results
313+
}
314+
for i := 0; i < len(results); i++ {
315+
if slices.Contains(s.excludedPartitions, results[i].ChildKey.PartitionKey) {
316+
results = utils.ReplaceWithLast(results, i)
317+
i--
318+
}
319+
}
320+
return results
321+
}
322+
301323
if firstBatch {
302324
ok, err := s.parent.NextBatch(ctx)
303325
if err != nil || !ok {
304326
return ok, err
305327
}
306-
s.parentResults = s.parent.SearchSet().PopResults()
328+
s.parentResults = filterPartitions(s.parent.SearchSet().PopResults())
307329
} else if len(s.parentResults) < s.beamSize {
308330
// Get more results from parent to try and fill the beam size.
309331
parentResults := s.parent.SearchSet().PopResults()
@@ -319,6 +341,7 @@ func (s *levelSearcher) NextBatch(ctx context.Context) (ok bool, err error) {
319341
}
320342
parentResults = s.parent.SearchSet().PopResults()
321343
}
344+
parentResults = filterPartitions(parentResults)
322345
if len(s.parentResults) == 0 {
323346
s.parentResults = parentResults
324347
} else {

0 commit comments

Comments
 (0)