Skip to content

Commit 314510c

Browse files
craig[bot] and andy-kimball committed
Merge #147881
147881: cspann: simplify partition assignment code r=drewkimball a=andy-kimball Previously, partition assignment worked by assigning vectors by their offsets in the vector set. AssignPartitions returned leftOffsets and rightOffsets sets containing the offsets for each partition. The offsets are awkward to handle and there's a TODO to switch this to use an assignments slice instead. This commit switches AssignPartitions to fill out an assignments slice, where 0 indicates a vector has been assigned to the left partition, or 1 to the right. It also updates all callers to use assignments rather than offsets. Epic: CRDB-42943 Release note: None Co-authored-by: Andrew Kimball <[email protected]>
2 parents dd2376a + 0d7286a commit 314510c

File tree

11 files changed

+463
-487
lines changed

11 files changed

+463
-487
lines changed

pkg/sql/vecindex/cspann/fixup_split.go

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -711,33 +711,33 @@ func (fw *fixupWorker) copyToSplitSubPartitions(
711711
vectors vector.Set,
712712
leftMetadata, rightMetadata PartitionMetadata,
713713
) (err error) {
714-
var leftOffsets, rightOffsets []uint64
714+
var leftCount int
715715
sourceState := sourcePartition.Metadata().StateDetails
716716

717717
defer func() {
718718
err = errors.Wrapf(err,
719719
"assigning %d vectors to left partition %d and %d vectors to right partition %d",
720-
len(leftOffsets), sourceState.Target1, len(rightOffsets), sourceState.Target2)
720+
leftCount, sourceState.Target1, vectors.Count-leftCount, sourceState.Target2)
721721
}()
722722

723-
tempOffsets := fw.workspace.AllocUint64s(vectors.Count)
724-
defer fw.workspace.FreeUint64s(tempOffsets)
723+
tempAssignments := fw.workspace.AllocUint64s(vectors.Count)
724+
defer fw.workspace.FreeUint64s(tempAssignments)
725725

726726
// Assign vectors to the partition with the nearest centroid.
727727
kmeans := BalancedKmeans{Workspace: &fw.workspace, Rand: fw.rng}
728-
leftOffsets, rightOffsets = kmeans.AssignPartitions(
729-
vectors, leftMetadata.Centroid, rightMetadata.Centroid, tempOffsets)
728+
leftCount = kmeans.AssignPartitions(
729+
vectors, leftMetadata.Centroid, rightMetadata.Centroid, tempAssignments)
730730

731731
// Assign vectors and associated keys and values into contiguous left and right groupings.
732732
childKeys := slices.Clone(sourcePartition.ChildKeys())
733733
valueBytes := slices.Clone(sourcePartition.ValueBytes())
734-
splitPartitionData(&fw.workspace, vectors, childKeys, valueBytes, leftOffsets, rightOffsets)
734+
splitPartitionData(&fw.workspace, vectors, childKeys, valueBytes, tempAssignments)
735735
leftVectors := vectors
736-
rightVectors := leftVectors.SplitAt(len(leftOffsets))
737-
leftChildKeys := childKeys[:len(leftOffsets)]
738-
rightChildKeys := childKeys[len(leftOffsets):]
739-
leftValueBytes := valueBytes[:len(leftOffsets)]
740-
rightValueBytes := valueBytes[len(leftOffsets):]
736+
rightVectors := leftVectors.SplitAt(leftCount)
737+
leftChildKeys := childKeys[:leftCount]
738+
rightChildKeys := childKeys[leftCount:]
739+
leftValueBytes := valueBytes[:leftCount]
740+
rightValueBytes := valueBytes[leftCount:]
741741

742742
// Add vectors to left and right sub-partitions. Note that this may not be
743743
// transactional; if an error occurs, any vectors already added may not be
@@ -748,7 +748,7 @@ func (fw *fixupWorker) copyToSplitSubPartitions(
748748
leftPartitionKey, leftVectors, leftChildKeys, leftValueBytes, leftMetadata)
749749
if added {
750750
log.VEventf(ctx, 2, "assigned %d vectors to left partition %d (level=%d, state=%s)",
751-
len(leftOffsets), leftPartitionKey, leftMetadata.Level, leftMetadata.StateDetails.String())
751+
leftCount, leftPartitionKey, leftMetadata.Level, leftMetadata.StateDetails.String())
752752
}
753753
if err != nil {
754754
return err
@@ -768,7 +768,7 @@ func (fw *fixupWorker) copyToSplitSubPartitions(
768768
rightPartitionKey, rightVectors, rightChildKeys, rightValueBytes, rightMetadata)
769769
if added {
770770
log.VEventf(ctx, 2, "assigned %d vectors to right partition %d (level=%d, state=%s)",
771-
len(rightOffsets), rightPartitionKey,
771+
vectors.Count-leftCount, rightPartitionKey,
772772
rightMetadata.Level, rightMetadata.StateDetails.String())
773773
}
774774
if err != nil {
@@ -798,62 +798,55 @@ func suppressRaceErrors(err error) (PartitionMetadata, error) {
798798
}
799799

800800
// splitPartitionData groups the provided partition data according to the left
801-
// and right offsets. All data referenced by left offsets will be moved to the
802-
// left of each set or slice. All data referenced by right offsets will be moved
803-
// to the right. The internal ordering of elements on each side is not defined.
804-
//
805-
// TODO(andyk): Passing in left and right offsets makes this overly complex. It
806-
// would be better to pass an assignments slice of the same length as the
807-
// partition data, where 0=left and 1=right.
801+
// and right assignments. The assignments slice specifies which partition the data
802+
// will be moved into: 0 for left and 1 for right. The internal ordering of
803+
// elements on each side is not defined.
808804
func splitPartitionData(
809805
w *workspace.T,
810806
vectors vector.Set,
811807
childKeys []ChildKey,
812808
valueBytes []ValueBytes,
813-
leftOffsets, rightOffsets []uint64,
809+
assignments []uint64,
814810
) {
815811
tempVector := w.AllocFloats(vectors.Dims)
816812
defer w.FreeFloats(tempVector)
817813

814+
// Use a two-pointer approach to partition the data. left points to the next
815+
// position where a left element should go. right points to the next position
816+
// where a right element should go (from the end).
818817
left := 0
819-
right := 0
818+
right := len(assignments) - 1
819+
820820
for {
821-
// Find a misplaced "right" element from the left side.
822-
var leftOffset int
823-
for {
824-
if left >= len(leftOffsets) {
825-
return
826-
}
827-
leftOffset = int(leftOffsets[left])
821+
// Find a misplaced element on the left side (should be 0 but is 1).
822+
for left < right && assignments[left] == 0 {
828823
left++
829-
if leftOffset >= len(leftOffsets) {
830-
break
831-
}
832824
}
833825

834-
// There must be a misplaced "left" element from the right side.
835-
var rightOffset int
836-
for {
837-
rightOffset = int(rightOffsets[right])
838-
right++
839-
if rightOffset < len(leftOffsets) {
840-
break
841-
}
826+
// Find a misplaced element on the right side (should be 1 but is 0).
827+
for left < right && assignments[right] == 1 {
828+
right--
829+
}
830+
831+
if left >= right {
832+
// No more misplaced elements, so break.
833+
break
842834
}
843835

844-
// Swap the two elements.
845-
rightToLeft := vectors.At(leftOffset)
846-
leftToRight := vectors.At(rightOffset)
847-
copy(tempVector, rightToLeft)
848-
copy(rightToLeft, leftToRight)
849-
copy(leftToRight, tempVector)
836+
// Swap vectors.
837+
leftVector := vectors.At(left)
838+
rightVector := vectors.At(right)
839+
copy(tempVector, leftVector)
840+
copy(leftVector, rightVector)
841+
copy(rightVector, tempVector)
842+
843+
// Swap child keys.
844+
childKeys[left], childKeys[right] = childKeys[right], childKeys[left]
850845

851-
tempChildKey := childKeys[leftOffset]
852-
childKeys[leftOffset] = childKeys[rightOffset]
853-
childKeys[rightOffset] = tempChildKey
846+
// Swap value bytes.
847+
valueBytes[left], valueBytes[right] = valueBytes[right], valueBytes[left]
854848

855-
tempValueBytes := valueBytes[leftOffset]
856-
valueBytes[leftOffset] = valueBytes[rightOffset]
857-
valueBytes[rightOffset] = tempValueBytes
849+
left++
850+
right--
858851
}
859852
}

pkg/sql/vecindex/cspann/fixup_split_test.go

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -45,78 +45,54 @@ func TestSplitPartitionData(t *testing.T) {
4545
}
4646

4747
testCases := []struct {
48-
desc string
49-
leftOffsets []uint64
50-
rightOffsets []uint64
51-
expectedLeft []uint64
52-
expectedRight []uint64
48+
desc string
49+
assignments []uint64
50+
expected []int
5351
}{
5452
{
55-
desc: "no reordering",
56-
leftOffsets: []uint64{0, 1, 2, 3},
57-
rightOffsets: []uint64{4, 5, 6},
53+
desc: "no reordering",
54+
assignments: []uint64{0, 0, 0, 0, 1, 1, 1},
55+
expected: []int{0, 1, 2, 3, 4, 5, 6},
5856
},
5957
{
60-
desc: "only one on left",
61-
leftOffsets: []uint64{1},
62-
rightOffsets: []uint64{0, 2, 3, 4, 5, 6},
58+
desc: "only one on left",
59+
assignments: []uint64{1, 0, 1, 1, 1, 1, 1},
60+
expected: []int{1, 0, 2, 3, 4, 5, 6},
6361
},
6462
{
65-
desc: "only one on right",
66-
leftOffsets: []uint64{0, 1, 2, 4, 5, 6},
67-
rightOffsets: []uint64{3},
63+
desc: "only one on right",
64+
assignments: []uint64{0, 0, 0, 1, 0, 0, 0},
65+
expected: []int{0, 1, 2, 6, 4, 5, 3},
6866
},
6967
{
70-
desc: "interleaved",
71-
leftOffsets: []uint64{0, 2, 4, 6},
72-
rightOffsets: []uint64{1, 3, 5},
68+
desc: "interleaved",
69+
assignments: []uint64{0, 1, 0, 1, 0, 1, 0},
70+
expected: []int{0, 6, 2, 4, 3, 5, 1},
7371
},
7472
{
75-
desc: "another interleaved",
76-
leftOffsets: []uint64{1, 4, 5},
77-
rightOffsets: []uint64{0, 2, 3, 6},
73+
desc: "another interleaved",
74+
assignments: []uint64{1, 0, 1, 1, 0, 0, 1},
75+
expected: []int{5, 1, 4, 3, 2, 0, 6},
7876
},
7977
{
80-
desc: "reversed",
81-
leftOffsets: []uint64{4, 5, 6},
82-
rightOffsets: []uint64{0, 1, 2, 3},
83-
},
84-
{
85-
desc: "out of order",
86-
leftOffsets: []uint64{5, 4, 6},
87-
rightOffsets: []uint64{3, 0, 1, 2},
78+
desc: "reversed",
79+
assignments: []uint64{1, 1, 1, 1, 0, 0, 0},
80+
expected: []int{6, 5, 4, 3, 2, 1, 0},
8881
},
8982
}
9083

91-
findKey := func(allKeys []ChildKey, toFind ChildKey) int {
92-
for i, key := range allKeys {
93-
if key.Equal(toFind) {
94-
return i
95-
}
96-
}
97-
return -1
98-
}
99-
10084
for _, tc := range testCases {
10185
t.Run(tc.desc, func(t *testing.T) {
10286
tempVectors := vectors.Clone()
10387
tempChildKeys := slices.Clone(childKeys)
10488
tempValueBytes := slices.Clone(valueBytes)
105-
splitPartitionData(&workspace,
106-
tempVectors, tempChildKeys, tempValueBytes, tc.leftOffsets, tc.rightOffsets)
89+
splitPartitionData(&workspace, tempVectors, tempChildKeys, tempValueBytes, tc.assignments)
10790

10891
// Ensure that partition data is on the correct side.
109-
for originalOffset := range childKeys {
110-
newOffset := findKey(tempChildKeys, childKeys[originalOffset])
111-
112-
if newOffset < len(tc.leftOffsets) {
113-
require.Contains(t, tc.leftOffsets, uint64(originalOffset))
114-
} else {
115-
require.Contains(t, tc.rightOffsets, uint64(originalOffset))
116-
}
117-
118-
require.Equal(t, tempVectors.At(newOffset), vectors.At(originalOffset))
119-
require.Equal(t, tempValueBytes[newOffset], valueBytes[originalOffset])
92+
for i := range tc.expected {
93+
require.Equal(t, tempVectors.At(tc.expected[i]), vectors.At(i))
94+
require.Equal(t, tempChildKeys[tc.expected[i]], childKeys[i])
95+
require.Equal(t, tempValueBytes[tc.expected[i]], valueBytes[i])
12096
}
12197
})
12298
}

0 commit comments

Comments
 (0)