Skip to content

Commit a226eb5

Browse files
committed
cspann: add utils.ReplaceWithLast helper function
Add generic ReplaceWithLast function that removes an element from a slice by replacing it with the last element and truncating the slice. Epic: CRDB-42943 Release note: None
1 parent 75e1394 commit a226eb5

File tree

11 files changed

+108
-43
lines changed

11 files changed

+108
-43
lines changed

pkg/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ ALL_TESTS = [
658658
"//pkg/sql/types:types_test",
659659
"//pkg/sql/vecindex/cspann/memstore:memstore_test",
660660
"//pkg/sql/vecindex/cspann/quantize:quantize_test",
661+
"//pkg/sql/vecindex/cspann/utils:utils_test",
661662
"//pkg/sql/vecindex/cspann/workspace:workspace_test",
662663
"//pkg/sql/vecindex/cspann:cspann_test",
663664
"//pkg/sql/vecindex/vecencoding:vecencoding_test",
@@ -2384,6 +2385,7 @@ GO_TARGETS = [
23842385
"//pkg/sql/vecindex/cspann/quantize:quantize_test",
23852386
"//pkg/sql/vecindex/cspann/testutils:testutils",
23862387
"//pkg/sql/vecindex/cspann/utils:utils",
2388+
"//pkg/sql/vecindex/cspann/utils:utils_test",
23872389
"//pkg/sql/vecindex/cspann/workspace:workspace",
23882390
"//pkg/sql/vecindex/cspann/workspace:workspace_test",
23892391
"//pkg/sql/vecindex/cspann:cspann",

pkg/sql/vecindex/cspann/fixup_split.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"math"
1111
"slices"
1212

13+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
1314
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/workspace"
1415
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1516
"github.com/cockroachdb/cockroach/pkg/util/log"
@@ -403,7 +404,7 @@ func (fw *fixupWorker) reassignToSiblings(
403404
fw.tempMetadataToGet = fw.tempMetadataToGet[:0]
404405
getSiblingMetadata := func() ([]PartitionMetadataToGet, error) {
405406
if len(fw.tempMetadataToGet) == 0 {
406-
fw.tempMetadataToGet = ensureSliceLen(fw.tempMetadataToGet, parentPartition.Count())
407+
fw.tempMetadataToGet = utils.EnsureSliceLen(fw.tempMetadataToGet, parentPartition.Count())
407408
for i := range len(fw.tempMetadataToGet) {
408409
fw.tempMetadataToGet[i].Key = parentPartition.ChildKeys()[i].PartitionKey
409410
}
@@ -519,7 +520,7 @@ func (fw *fixupWorker) getPartition(
519520
func (fw *fixupWorker) getPartitionMetadata(
520521
ctx context.Context, partitionKey PartitionKey,
521522
) (PartitionMetadata, error) {
522-
fw.tempMetadataToGet = ensureSliceLen(fw.tempMetadataToGet, 1)
523+
fw.tempMetadataToGet = utils.EnsureSliceLen(fw.tempMetadataToGet, 1)
523524
fw.tempMetadataToGet[0].Key = partitionKey
524525
err := fw.index.store.TryGetPartitionMetadata(ctx, fw.treeKey, fw.tempMetadataToGet)
525526
if err != nil {

pkg/sql/vecindex/cspann/fixup_worker.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"math/rand"
1111

12+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
1213
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/workspace"
1314
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1415
"github.com/cockroachdb/cockroach/pkg/util/log"
@@ -169,7 +170,7 @@ func (fw *fixupWorker) deleteVector(
169170
// against a race condition where a row is created and deleted repeatedly with
170171
// the same primary key.
171172
childKey := ChildKey{KeyBytes: vectorKey}
172-
fw.tempVectorsWithKeys = ensureSliceLen(fw.tempVectorsWithKeys, 1)
173+
fw.tempVectorsWithKeys = utils.EnsureSliceLen(fw.tempVectorsWithKeys, 1)
173174
fw.tempVectorsWithKeys[0] = VectorWithKey{Key: childKey}
174175
if err = txn.GetFullVectors(ctx, fw.treeKey, fw.tempVectorsWithKeys); err != nil {
175176
return errors.Wrap(err, "getting full vector")
@@ -213,7 +214,7 @@ func (fw *fixupWorker) getFullVectorsForPartition(
213214

214215
err = fw.index.store.RunTransaction(ctx, func(txn Txn) error {
215216
childKeys := partition.ChildKeys()
216-
fw.tempVectorsWithKeys = ensureSliceLen(fw.tempVectorsWithKeys, len(childKeys))
217+
fw.tempVectorsWithKeys = utils.EnsureSliceLen(fw.tempVectorsWithKeys, len(childKeys))
217218
for i := range childKeys {
218219
fw.tempVectorsWithKeys[i] = VectorWithKey{Key: childKeys[i]}
219220
}

pkg/sql/vecindex/cspann/index.go

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ func (vi *Index) findExactDistances(
921921
}
922922

923923
// Prepare vector references.
924-
idxCtx.tempVectorsWithKeys = ensureSliceLen(idxCtx.tempVectorsWithKeys, len(candidates))
924+
idxCtx.tempVectorsWithKeys = utils.EnsureSliceLen(idxCtx.tempVectorsWithKeys, len(candidates))
925925
for i := range candidates {
926926
idxCtx.tempVectorsWithKeys[i].Key = candidates[i].ChildKey
927927
}
@@ -950,9 +950,7 @@ func (vi *Index) findExactDistances(
950950
// Move the last candidate to the current position and reduce size
951951
// of slice by one.
952952
idxCtx.tempVectorsWithKeys[i] = idxCtx.tempVectorsWithKeys[len(candidates)-1]
953-
candidates[i] = candidates[len(candidates)-1]
954-
candidates[len(candidates)-1] = SearchResult{} // for GC
955-
candidates = candidates[:len(candidates)-1]
953+
candidates = utils.ReplaceWithLast(candidates, i)
956954
} else {
957955
i++
958956
}
@@ -1144,17 +1142,3 @@ func (vi *Index) Format(
11441142
}
11451143
return buf.String(), nil
11461144
}
1147-
1148-
// ensureSliceLen returns a slice of the given length and generic type. If the
1149-
// existing slice has enough capacity, that slice is returned after adjusting
1150-
// its length. Otherwise, a new, larger slice is allocated.
1151-
// NOTE: Every element of the new slice is uninitialized; callers are
1152-
// responsible for initializing the memory.
1153-
func ensureSliceLen[T any](s []T, l int) []T {
1154-
// In test builds, always allocate new memory, to catch bugs where callers
1155-
// assume existing slice elements will be copied.
1156-
if cap(s) < l || buildutil.CrdbTestBuild {
1157-
return make([]T, l)
1158-
}
1159-
return s[:l]
1160-
}

pkg/sql/vecindex/cspann/partition.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"slices"
1010

1111
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/quantize"
12+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
1213
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/workspace"
1314
"github.com/cockroachdb/cockroach/pkg/util/vector"
1415
)
@@ -235,13 +236,8 @@ func (p *Partition) AddSet(
235236
// position changes.
236237
func (p *Partition) ReplaceWithLast(offset int) {
237238
p.quantizedSet.ReplaceWithLast(offset)
238-
newCount := len(p.childKeys) - 1
239-
p.childKeys[offset] = p.childKeys[newCount]
240-
p.childKeys[newCount] = ChildKey{} // for GC
241-
p.childKeys = p.childKeys[:newCount]
242-
p.valueBytes[offset] = p.valueBytes[newCount]
243-
p.valueBytes[newCount] = nil // for GC
244-
p.valueBytes = p.valueBytes[:newCount]
239+
p.childKeys = utils.ReplaceWithLast(p.childKeys, offset)
240+
p.valueBytes = utils.ReplaceWithLast(p.valueBytes, offset)
245241
}
246242

247243
// ReplaceWithLastByKey calls ReplaceWithLast with the offset of the given child

pkg/sql/vecindex/cspann/quantize/rabitqpb.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
math "math"
1010
"slices"
1111

12+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
1213
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1314
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
1415
"github.com/cockroachdb/cockroach/pkg/util/num32"
@@ -117,18 +118,13 @@ func (vs *RaBitQuantizedVectorSet) GetCount() int {
117118

118119
// ReplaceWithLast implements the QuantizedVectorSet interface.
119120
func (vs *RaBitQuantizedVectorSet) ReplaceWithLast(offset int) {
120-
lastOffset := len(vs.CodeCounts) - 1
121121
vs.Codes.ReplaceWithLast(offset)
122-
vs.CodeCounts[offset] = vs.CodeCounts[lastOffset]
123-
vs.CodeCounts = vs.CodeCounts[:lastOffset]
124-
vs.CentroidDistances[offset] = vs.CentroidDistances[lastOffset]
125-
vs.CentroidDistances = vs.CentroidDistances[:lastOffset]
126-
vs.QuantizedDotProducts[offset] = vs.QuantizedDotProducts[lastOffset]
127-
vs.QuantizedDotProducts = vs.QuantizedDotProducts[:lastOffset]
122+
vs.CodeCounts = utils.ReplaceWithLast(vs.CodeCounts, offset)
123+
vs.CentroidDistances = utils.ReplaceWithLast(vs.CentroidDistances, offset)
124+
vs.QuantizedDotProducts = utils.ReplaceWithLast(vs.QuantizedDotProducts, offset)
128125
if vs.CentroidDotProducts != nil {
129126
// This is nil for the L2Squared distance metric.
130-
vs.CentroidDotProducts[offset] = vs.CentroidDotProducts[lastOffset]
131-
vs.CentroidDotProducts = vs.CentroidDotProducts[:lastOffset]
127+
vs.CentroidDotProducts = utils.ReplaceWithLast(vs.CentroidDotProducts, offset)
132128
}
133129
}
134130

pkg/sql/vecindex/cspann/query_comparer.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package cspann
77

88
import (
9+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
910
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
1011
"github.com/cockroachdb/cockroach/pkg/util/num32"
1112
"github.com/cockroachdb/cockroach/pkg/util/vector"
@@ -35,7 +36,7 @@ func (c *queryComparer) Init(
3536
c.original = queryVector
3637

3738
// Randomize the original query vector.
38-
c.randomized = ensureSliceLen(c.randomized, len(queryVector))
39+
c.randomized = utils.EnsureSliceLen(c.randomized, len(queryVector))
3940
c.randomized = rot.RandomizeVector(queryVector, c.randomized)
4041

4142
// If using cosine distance, also normalize the query vector.

pkg/sql/vecindex/cspann/searcher.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"math"
1111

12+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils"
1213
"github.com/cockroachdb/errors"
1314
)
1415

@@ -99,14 +100,14 @@ func (s *searcher) Next(ctx context.Context) (ok bool, err error) {
99100
ChildKey: ChildKey{PartitionKey: RootKey},
100101
})
101102
// Ensure that if Next() is called again, it will return false.
102-
s.levels = ensureSliceLen(s.levels, 1)
103+
s.levels = utils.EnsureSliceLen(s.levels, 1)
103104
s.levels[0] = *root
104105
return root.NextBatch(ctx)
105106
}
106107

107108
// Set up remainder of searchers now that we know the root's level.
108109
n := int(root.Level()-s.idxCtx.level) + 1
109-
s.levels = ensureSliceLen(s.levels, n)
110+
s.levels = utils.EnsureSliceLen(s.levels, n)
110111
s.levels[0] = *root
111112
for i := 1; i < n; i++ {
112113
var maxResults, maxExtraResults int
@@ -381,7 +382,7 @@ func (s *levelSearcher) searchChildPartitions(
381382
return InvalidLevel, nil
382383
}
383384

384-
s.idxCtx.tempToSearch = ensureSliceLen(s.idxCtx.tempToSearch, len(parentResults))
385+
s.idxCtx.tempToSearch = utils.EnsureSliceLen(s.idxCtx.tempToSearch, len(parentResults))
385386
for i := range parentResults {
386387
// If this is an Insert or SearchForInsert operation, then do not scan
387388
// leaf vectors. Insert operations never need leaf vectors and scanning

pkg/sql/vecindex/cspann/utils/BUILD.bazel

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
load("@io_bazel_rules_go//go:def.bzl", "go_library")
1+
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
22

33
go_library(
44
name = "utils",
55
srcs = [
66
"format.go",
7+
"slice.go",
78
"validate.go",
89
],
910
importpath = "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/utils",
@@ -16,3 +17,10 @@ go_library(
1617
"@org_gonum_v1_gonum//floats/scalar",
1718
],
1819
)
20+
21+
go_test(
22+
name = "utils_test",
23+
srcs = ["slice_test.go"],
24+
embed = [":utils"],
25+
deps = ["@com_github_stretchr_testify//require"],
26+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package utils
7+
8+
import "github.com/cockroachdb/cockroach/pkg/util/buildutil"
9+
10+
// ReplaceWithLast removes an element from a slice by replacing it with the last
11+
// element and truncating the slice. This operation preserves the order of all
12+
// elements except the removed one, which is replaced by the last element.
13+
// The operation completes in O(1) time.
14+
func ReplaceWithLast[T any](s []T, i int) []T {
15+
l := len(s) - 1
16+
s[i] = s[l]
17+
return s[:l]
18+
}
19+
20+
// EnsureSliceLen returns a slice of the given length and generic type. If the
21+
// existing slice has enough capacity, that slice is returned after adjusting
22+
// its length. Otherwise, a new, larger slice is allocated.
23+
// NOTE: Every element of the new slice is uninitialized; callers are
24+
// responsible for initializing the memory.
25+
func EnsureSliceLen[T any](s []T, l int) []T {
26+
// In test builds, always allocate new memory, to catch bugs where callers
27+
// assume existing slice elements will be copied.
28+
if cap(s) < l || buildutil.CrdbTestBuild {
29+
return make([]T, l)
30+
}
31+
return s[:l]
32+
}

0 commit comments

Comments
 (0)