Skip to content

Commit 5775bd6

Browse files
committed
cspann: fix partition data sorting bug
When sorting partition data into left and right groupings, vectors can be sorted in a different order than associated child keys and value bytes. This commit updates the logic to operate on all the partition data at once, not just the vectors. Epic: CRDB-42943 Release note: None
1 parent c712dba commit 5775bd6

File tree

8 files changed

+341
-221
lines changed

8 files changed

+341
-221
lines changed

pkg/sql/vecindex/cspann/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ go_test(
5353
"childkey_dedup_test.go",
5454
"cspannpb_test.go",
5555
"fixup_processor_test.go",
56+
"fixup_split_test.go",
5657
"fixup_worker_test.go",
5758
"index_stats_test.go",
5859
"index_test.go",

pkg/sql/vecindex/cspann/fixup_split.go

Lines changed: 62 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (fw *fixupWorker) splitPartition(
139139
}
140140

141141
log.VEventf(ctx, 2, "splitting partition %d with %d vectors (parent=%d, state=%s)",
142-
partitionKey, parentPartitionKey, partition.Count(), metadata.StateDetails.String())
142+
partitionKey, partition.Count(), parentPartitionKey, metadata.StateDetails.String())
143143

144144
// Update partition's state to Splitting.
145145
if metadata.StateDetails.State == ReadyState {
@@ -454,13 +454,13 @@ func (fw *fixupWorker) createSplitSubPartition(
454454
partitionKey PartitionKey,
455455
centroid vector.T,
456456
) (targetMetadata PartitionMetadata, err error) {
457-
const format = "creating split sub-partition %d, with parent %d"
457+
const format = "creating split sub-partition %d (source=%d, parent=%d)"
458458

459459
defer func() {
460-
err = errors.Wrapf(err, format, partitionKey, parentPartitionKey)
460+
err = errors.Wrapf(err, format, partitionKey, sourcePartitionKey, parentPartitionKey)
461461
}()
462462

463-
log.VEventf(ctx, 2, format, partitionKey, parentPartitionKey)
463+
log.VEventf(ctx, 2, format, partitionKey, sourcePartitionKey, parentPartitionKey)
464464

465465
// Create an empty partition in the Updating state.
466466
targetMetadata = PartitionMetadata{
@@ -548,15 +548,15 @@ func (fw *fixupWorker) addToParentPartition(
548548
func (fw *fixupWorker) deletePartition(
549549
ctx context.Context, parentPartitionKey, partitionKey PartitionKey,
550550
) (err error) {
551-
const format = "deleting partition %d (parent=%d, state=%d)"
551+
const format = "deleting partition %d (parent=%d, state=%s)"
552552
var parentMetadata PartitionMetadata
553553

554554
defer func() {
555555
err = errors.Wrapf(err, format,
556-
partitionKey, parentPartitionKey, parentMetadata.StateDetails.State)
556+
partitionKey, parentPartitionKey, parentMetadata.StateDetails.String())
557557
}()
558558

559-
// Load parent partition to verify that it's in a state that allows inserts.
559+
// Load parent partition to verify that it's in a state that allows removes.
560560
var parentPartition *Partition
561561
parentPartition, err = fw.getPartition(ctx, parentPartitionKey)
562562
if err != nil {
@@ -566,7 +566,8 @@ func (fw *fixupWorker) deletePartition(
566566
parentMetadata = *parentPartition.Metadata()
567567
}
568568

569-
log.VEventf(ctx, 2, format, partitionKey, parentPartitionKey, parentMetadata.StateDetails.State)
569+
log.VEventf(ctx, 2, format,
570+
partitionKey, parentPartitionKey, parentMetadata.StateDetails.String())
570571

571572
if !parentMetadata.StateDetails.State.AllowAddOrRemove() {
572573
// Child could not be removed from the parent because it doesn't exist or
@@ -640,21 +641,16 @@ func (fw *fixupWorker) copyToSplitSubPartitions(
640641
leftOffsets, rightOffsets = kmeans.AssignPartitions(
641642
vectors, leftMetadata.Centroid, rightMetadata.Centroid, tempOffsets)
642643

643-
// Sort vectors into contiguous left and right groupings.
644-
sortVectors(&fw.workspace, vectors, leftOffsets, rightOffsets)
644+
// Assign vectors and associated keys and values into contiguous left and right groupings.
645+
childKeys := slices.Clone(sourcePartition.ChildKeys())
646+
valueBytes := slices.Clone(sourcePartition.ValueBytes())
647+
splitPartitionData(&fw.workspace, vectors, childKeys, valueBytes, leftOffsets, rightOffsets)
645648
leftVectors := vectors
646649
rightVectors := leftVectors.SplitAt(len(leftOffsets))
647-
648-
childKeys := make([]ChildKey, vectors.Count)
649-
valueBytes := make([]ValueBytes, vectors.Count)
650-
leftChildKeys := copyByOffsets(
651-
sourcePartition.ChildKeys(), childKeys[:len(leftOffsets)], leftOffsets)
652-
rightChildKeys := copyByOffsets(
653-
sourcePartition.ChildKeys(), childKeys[len(leftOffsets):], rightOffsets)
654-
leftValueBytes := copyByOffsets(
655-
sourcePartition.ValueBytes(), valueBytes[:len(leftOffsets)], leftOffsets)
656-
rightValueBytes := copyByOffsets(
657-
sourcePartition.ValueBytes(), valueBytes[len(leftOffsets):], rightOffsets)
650+
leftChildKeys := childKeys[:len(leftOffsets)]
651+
rightChildKeys := childKeys[len(leftOffsets):]
652+
leftValueBytes := valueBytes[:len(leftOffsets)]
653+
rightValueBytes := valueBytes[len(leftOffsets):]
658654

659655
log.VEventf(ctx, 2, format,
660656
len(leftOffsets), sourceState.Target1, len(rightOffsets), sourceState.Target2)
@@ -740,63 +736,63 @@ func suppressRaceErrors(err error) (PartitionMetadata, error) {
740736
return PartitionMetadata{}, err
741737
}
742738

743-
// sortVectors sorts the input vectors in-place, according to the provided left
744-
// and right offsets, which reference vectors by position. Vectors at left
745-
// offsets are sorted at the beginning of the slice, followed by vectors at
746-
// right offsets. The internal ordering among left and right vectors is not
747-
// defined.
739+
// splitPartitionData groups the provided partition data according to the left
740+
// and right offsets. All data referenced by left offsets will be moved to the
741+
// left of each set or slice. All data referenced by right offsets will be moved
742+
// to the right. The internal ordering of elements on each side is not defined.
748743
//
749-
// NOTE: The left and right offsets are modified in-place with the updated
750-
// positions of the vectors.
751-
func sortVectors(w *workspace.T, vectors vector.Set, leftOffsets, rightOffsets []uint64) {
744+
// TODO(andyk): Passing in left and right offsets makes this overly complex. It
745+
// would be better to pass an assignments slice of the same length as the
746+
// partition data, where 0=left and 1=right.
747+
func splitPartitionData(
748+
w *workspace.T,
749+
vectors vector.Set,
750+
childKeys []ChildKey,
751+
valueBytes []ValueBytes,
752+
leftOffsets, rightOffsets []uint64,
753+
) {
752754
tempVector := w.AllocFloats(vectors.Dims)
753755
defer w.FreeFloats(tempVector)
754756

755-
// Sort left and right offsets.
756-
slices.Sort(leftOffsets)
757-
slices.Sort(rightOffsets)
758-
759-
// Any left offsets that point beyond the end of the left list indicate that
760-
// a vector needs to be moved from the right half of vectors to the left half.
761-
// The reverse is true for right offsets. Because the left and right offsets
762-
// are in sorted order, out-of-bounds offsets must be at the end of the left
763-
// list and the beginning of the right list. Therefore, the algorithm just
764-
// needs to iterate over those out-of-bounds offsets and swap the positions
765-
// of the referenced vectors.
766-
li := len(leftOffsets) - 1
767-
ri := 0
768-
769-
var rightToLeft, leftToRight vector.T
770-
for li >= 0 {
771-
left := int(leftOffsets[li])
772-
if left < len(leftOffsets) {
773-
break
757+
left := 0
758+
right := 0
759+
for {
760+
// Find a misplaced "right" element from the left side.
761+
var leftOffset int
762+
for {
763+
if left >= len(leftOffsets) {
764+
return
765+
}
766+
leftOffset = int(leftOffsets[left])
767+
left++
768+
if leftOffset >= len(leftOffsets) {
769+
break
770+
}
774771
}
775772

776-
right := int(rightOffsets[ri])
777-
if right >= len(leftOffsets) {
778-
panic(errors.AssertionFailedf(
779-
"expected equal number of left and right offsets that need to be swapped"))
773+
// There must be a misplaced "left" element from the right side.
774+
var rightOffset int
775+
for {
776+
rightOffset = int(rightOffsets[right])
777+
right++
778+
if rightOffset < len(leftOffsets) {
779+
break
780+
}
780781
}
781782

782-
// Swap vectors.
783-
rightToLeft = vectors.At(left)
784-
leftToRight = vectors.At(right)
783+
// Swap the two elements.
784+
rightToLeft := vectors.At(leftOffset)
785+
leftToRight := vectors.At(rightOffset)
785786
copy(tempVector, rightToLeft)
786787
copy(rightToLeft, leftToRight)
787788
copy(leftToRight, tempVector)
788789

789-
leftOffsets[li] = uint64(left)
790-
rightOffsets[ri] = uint64(right)
791-
792-
li--
793-
ri++
794-
}
795-
}
790+
tempChildKey := childKeys[leftOffset]
791+
childKeys[leftOffset] = childKeys[rightOffset]
792+
childKeys[rightOffset] = tempChildKey
796793

797-
func copyByOffsets[T any](source, target []T, offsets []uint64) []T {
798-
for i := range offsets {
799-
target[i] = source[offsets[i]]
794+
tempValueBytes := valueBytes[leftOffset]
795+
valueBytes[leftOffset] = valueBytes[rightOffset]
796+
valueBytes[rightOffset] = tempValueBytes
800797
}
801-
return target
802798
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package cspann
7+
8+
import (
9+
"slices"
10+
"testing"
11+
12+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/workspace"
13+
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
14+
"github.com/cockroachdb/cockroach/pkg/util/log"
15+
"github.com/cockroachdb/cockroach/pkg/util/vector"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
func TestSplitPartitionData(t *testing.T) {
20+
defer leaktest.AfterTest(t)()
21+
defer log.Scope(t).Close(t)
22+
23+
var workspace workspace.T
24+
vectors := vector.MakeSetFromRawData([]float32{
25+
0, 0,
26+
1, 1,
27+
2, 3,
28+
3, 3,
29+
4, 4,
30+
5, 5,
31+
6, 6,
32+
}, 2)
33+
34+
childKeys := []ChildKey{
35+
{KeyBytes: KeyBytes("vec1")},
36+
{KeyBytes: KeyBytes("vec2")},
37+
{KeyBytes: KeyBytes("vec3")},
38+
{KeyBytes: KeyBytes("vec4")},
39+
{KeyBytes: KeyBytes("vec5")},
40+
{KeyBytes: KeyBytes("vec6")},
41+
{KeyBytes: KeyBytes("vec7")},
42+
}
43+
valueBytes := []ValueBytes{
44+
{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {6, 6}, {7, 7},
45+
}
46+
47+
testCases := []struct {
48+
desc string
49+
leftOffsets []uint64
50+
rightOffsets []uint64
51+
expectedLeft []uint64
52+
expectedRight []uint64
53+
}{
54+
{
55+
desc: "no reordering",
56+
leftOffsets: []uint64{0, 1, 2, 3},
57+
rightOffsets: []uint64{4, 5, 6},
58+
},
59+
{
60+
desc: "only one on left",
61+
leftOffsets: []uint64{1},
62+
rightOffsets: []uint64{0, 2, 3, 4, 5, 6},
63+
},
64+
{
65+
desc: "only one on right",
66+
leftOffsets: []uint64{0, 1, 2, 4, 5, 6},
67+
rightOffsets: []uint64{3},
68+
},
69+
{
70+
desc: "interleaved",
71+
leftOffsets: []uint64{0, 2, 4, 6},
72+
rightOffsets: []uint64{1, 3, 5},
73+
},
74+
{
75+
desc: "another interleaved",
76+
leftOffsets: []uint64{1, 4, 5},
77+
rightOffsets: []uint64{0, 2, 3, 6},
78+
},
79+
{
80+
desc: "reversed",
81+
leftOffsets: []uint64{4, 5, 6},
82+
rightOffsets: []uint64{0, 1, 2, 3},
83+
},
84+
{
85+
desc: "out of order",
86+
leftOffsets: []uint64{5, 4, 6},
87+
rightOffsets: []uint64{3, 0, 1, 2},
88+
},
89+
}
90+
91+
findKey := func(allKeys []ChildKey, toFind ChildKey) int {
92+
for i, key := range allKeys {
93+
if key.Equal(toFind) {
94+
return i
95+
}
96+
}
97+
return -1
98+
}
99+
100+
for _, tc := range testCases {
101+
t.Run(tc.desc, func(t *testing.T) {
102+
tempVectors := vectors.Clone()
103+
tempChildKeys := slices.Clone(childKeys)
104+
tempValueBytes := slices.Clone(valueBytes)
105+
splitPartitionData(&workspace,
106+
tempVectors, tempChildKeys, tempValueBytes, tc.leftOffsets, tc.rightOffsets)
107+
108+
// Ensure that partition data is on the correct side.
109+
for originalOffset := range childKeys {
110+
newOffset := findKey(tempChildKeys, childKeys[originalOffset])
111+
112+
if newOffset < len(tc.leftOffsets) {
113+
require.Contains(t, tc.leftOffsets, uint64(originalOffset))
114+
} else {
115+
require.Contains(t, tc.rightOffsets, uint64(originalOffset))
116+
}
117+
118+
require.Equal(t, tempVectors.At(newOffset), vectors.At(originalOffset))
119+
require.Equal(t, tempValueBytes[newOffset], valueBytes[originalOffset])
120+
}
121+
})
122+
}
123+
}

pkg/sql/vecindex/cspann/fixup_worker.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func (fw *fixupWorker) oldSplitPartition(
278278
tempLeftOffsets, tempRightOffsets := kmeans.ComputeCentroids(
279279
vectors, tempLeftCentroid, tempRightCentroid, false /* pinLeftCentroid */, tempOffsets)
280280

281-
leftSplit, rightSplit := splitPartitionData(
281+
leftSplit, rightSplit := oldSplitPartitionData(
282282
&fw.workspace, fw.index.quantizer, partition, vectors,
283283
tempLeftOffsets, tempRightOffsets)
284284

@@ -424,7 +424,7 @@ func (fw *fixupWorker) oldSplitPartition(
424424
// NOTE: The vectors set will be updated in-place, via a partial sort that moves
425425
// vectors in the left partition to the left side of the set. However, the split
426426
// partition is not modified.
427-
func splitPartitionData(
427+
func oldSplitPartitionData(
428428
w *workspace.T,
429429
quantizer quantize.Quantizer,
430430
splitPartition *Partition,

pkg/sql/vecindex/cspann/fixup_worker_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
"github.com/stretchr/testify/require"
1818
)
1919

20-
func TestSplitPartitionData(t *testing.T) {
20+
func TestOldSplitPartitionData(t *testing.T) {
2121
defer leaktest.AfterTest(t)()
2222
defer log.Scope(t).Close(t)
2323

@@ -132,7 +132,7 @@ func TestSplitPartitionData(t *testing.T) {
132132
t.Run(tc.desc, func(t *testing.T) {
133133
tempVectors := vector.MakeSet(2)
134134
tempVectors.AddSet(vectors)
135-
leftSplit, rightSplit := splitPartitionData(
135+
leftSplit, rightSplit := oldSplitPartitionData(
136136
&workspace, quantizer, splitPartition, tempVectors, tc.leftOffsets, tc.rightOffsets)
137137

138138
validate(&leftSplit, tc.expectedLeft)

0 commit comments

Comments
 (0)