Skip to content

Commit 1556d47

Browse files
committed
pull in latest changes from spicedb
1 parent 47a34fc commit 1556d47

File tree

9 files changed

+734
-161
lines changed

9 files changed

+734
-161
lines changed

go.mod

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,46 @@
11
module github.com/authzed/crdbpool
22

3-
go 1.20
3+
go 1.24.0
4+
5+
toolchain go1.24.6
46

57
require (
6-
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-rc.5
7-
github.com/jackc/pgx/v5 v5.3.2-0.20230529162321-9720d0d63faf
8+
github.com/ccoveille/go-safecast v1.6.1
9+
github.com/google/uuid v1.6.0
10+
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2
11+
github.com/jackc/pgx/v5 v5.7.5
812
github.com/lthibault/jitterbug v2.0.0+incompatible
9-
github.com/prometheus/client_golang v1.15.1
10-
github.com/rs/zerolog v1.29.0
11-
github.com/stretchr/testify v1.8.4
12-
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
13-
golang.org/x/sync v0.2.0
14-
golang.org/x/time v0.3.0
13+
github.com/prometheus/client_golang v1.22.0
14+
github.com/rs/zerolog v1.34.0
15+
github.com/stretchr/testify v1.10.0
16+
golang.org/x/sync v0.15.0
17+
golang.org/x/time v0.12.0
18+
google.golang.org/grpc v1.73.0
1519
)
1620

1721
require (
1822
github.com/beorn7/perks v1.0.1 // indirect
19-
github.com/cespare/xxhash/v2 v2.2.0 // indirect
20-
github.com/davecgh/go-spew v1.1.1 // indirect
21-
github.com/golang/protobuf v1.5.3 // indirect
23+
github.com/cespare/xxhash/v2 v2.3.0 // indirect
24+
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
25+
github.com/go-logr/logr v1.4.3 // indirect
2226
github.com/jackc/pgpassfile v1.0.0 // indirect
23-
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
24-
github.com/jackc/puddle/v2 v2.2.0 // indirect
25-
github.com/kr/text v0.2.0 // indirect
26-
github.com/magefile/mage v1.15.0 // indirect
27-
github.com/mattn/go-colorable v0.1.13 // indirect
28-
github.com/mattn/go-isatty v0.0.16 // indirect
29-
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
30-
github.com/pmezard/go-difflib v1.0.0 // indirect
31-
github.com/prometheus/client_model v0.3.0 // indirect
32-
github.com/prometheus/common v0.42.0 // indirect
33-
github.com/prometheus/procfs v0.9.0 // indirect
34-
github.com/rogpeppe/go-internal v1.10.0 // indirect
35-
golang.org/x/crypto v0.9.0 // indirect
36-
golang.org/x/net v0.10.0 // indirect
37-
golang.org/x/sys v0.8.0 // indirect
38-
golang.org/x/text v0.9.0 // indirect
39-
google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 // indirect
40-
google.golang.org/grpc v1.54.0 // indirect
41-
google.golang.org/protobuf v1.30.0 // indirect
27+
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
28+
github.com/jackc/puddle/v2 v2.2.2 // indirect
29+
github.com/mattn/go-colorable v0.1.14 // indirect
30+
github.com/mattn/go-isatty v0.0.20 // indirect
31+
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
32+
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
33+
github.com/prometheus/client_model v0.6.2 // indirect
34+
github.com/prometheus/common v0.64.0 // indirect
35+
github.com/prometheus/procfs v0.15.1 // indirect
36+
github.com/rogpeppe/go-internal v1.14.1 // indirect
37+
go.opentelemetry.io/otel v1.36.0 // indirect
38+
go.opentelemetry.io/otel/sdk/metric v1.36.0 // indirect
39+
golang.org/x/crypto v0.39.0 // indirect
40+
golang.org/x/net v0.41.0 // indirect
41+
golang.org/x/sys v0.33.0 // indirect
42+
golang.org/x/text v0.26.0 // indirect
43+
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
44+
google.golang.org/protobuf v1.36.6 // indirect
4245
gopkg.in/yaml.v3 v3.0.1 // indirect
4346
)

go.sum

Lines changed: 83 additions & 72 deletions
Large diffs are not rendered by default.

pkg/balancer.go

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@ package pool
33
import (
44
"context"
55
"hash/maphash"
6+
"maps"
7+
"math"
68
"math/rand"
9+
"slices"
710
"strconv"
811
"time"
912

13+
"github.com/ccoveille/go-safecast"
1014
"github.com/jackc/pgx/v5"
1115
"github.com/jackc/pgx/v5/pgxpool"
1216
"github.com/prometheus/client_golang/prometheus"
13-
"golang.org/x/exp/maps"
14-
"golang.org/x/exp/slices"
1517
"golang.org/x/sync/semaphore"
1618

1719
log "github.com/rs/zerolog"
@@ -83,14 +85,26 @@ type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct {
8385
// newNodeConnectionBalancer is generic over underlying connection types for
8486
// testing purposes. Callers should use the exported NewNodeConnectionBalancer.
8587
func newNodeConnectionBalancer[P balancePoolConn[C], C balanceConn](pool balanceablePool[P, C], healthTracker *NodeHealthTracker, interval time.Duration) *nodeConnectionBalancer[P, C] {
86-
seed := int64(new(maphash.Hash).Sum64())
88+
seed := int64(0)
89+
for seed == 0 {
90+
// Sum64 returns a uint64, and safecast will return 0 if it's not castable,
91+
// which will happen about half the time (?). We just keep running it until
92+
// we get a seed that fits in the box.
93+
// Note that subtracting math.MaxInt64 wraps around in uint64 arithmetic, so it
94+
// only shifts the hash value; the accepted seed is always a positive int64.
95+
seed, _ = safecast.ToInt64(new(maphash.Hash).Sum64() - math.MaxInt64)
96+
}
8797
return &nodeConnectionBalancer[P, C]{
8898
ticker: time.NewTicker(interval),
8999
sem: semaphore.NewWeighted(1),
90100
healthTracker: healthTracker,
91101
pool: pool,
92102
seed: seed,
93-
rnd: rand.New(rand.NewSource(seed)),
103+
// nolint:gosec
104+
// use of non cryptographically secure random number generator is not a concern here,
105+
// as it's used for shuffling the nodes to balance the connections when the number of
106+
// connections does not divide evenly.
107+
rnd: rand.New(rand.NewSource(seed)),
94108
}
95109
}
96110

@@ -104,18 +118,18 @@ func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) {
104118
case <-p.ticker.C:
105119
if p.sem.TryAcquire(1) {
106120
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
107-
p.pruneConnections(ctx)
121+
p.mustPruneConnections(ctx)
108122
cancel()
109123
p.sem.Release(1)
110124
}
111125
}
112126
}
113127
}
114128

115-
// pruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes)
129+
// mustPruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes)
116130
// This causes the pool to reconnect, which over time will lead to a balanced number of connections
117131
// across each node.
118-
func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
132+
func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context) {
119133
start := time.Now()
120134
defer func() {
121135
pruningTimeHistogram.WithLabelValues(p.pool.ID()).Observe(float64(time.Since(start).Milliseconds()))
@@ -142,7 +156,9 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
142156
}
143157
}
144158

145-
nodeCount := uint32(p.healthTracker.HealthyNodeCount())
159+
// It's highly unlikely that we'll ever have an overflow in this context, so
160+
// we ignore the conversion error; a zero result is replaced with 1 below.
161+
nodeCount, _ := safecast.ToUint32(p.healthTracker.HealthyNodeCount())
146162
if nodeCount == 0 {
147163
nodeCount = 1
148164
}
@@ -169,7 +185,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
169185
}
170186
p.healthTracker.RUnlock()
171187

172-
nodes := maps.Keys(connectionCounts)
188+
nodes := slices.Collect(maps.Keys(connectionCounts))
173189
slices.Sort(nodes)
174190

175191
// Shuffle nodes in place deterministically based on the initial seed.
@@ -198,7 +214,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
198214
// it's possible for the difference in connections between nodes to differ by up to
199215
// the number of nodes.
200216
if p.healthTracker.HealthyNodeCount() == 0 ||
201-
uint32(i) < p.pool.MaxConns()%uint32(p.healthTracker.HealthyNodeCount()) {
217+
i < int(p.pool.MaxConns())%p.healthTracker.HealthyNodeCount() {
202218
perNodeMax++
203219
}
204220

@@ -208,7 +224,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
208224
if count <= perNodeMax {
209225
continue
210226
}
211-
log.Ctx(ctx).Info().
227+
log.Ctx(ctx).Trace().
212228
Uint32("node", node).
213229
Uint32("poolmaxconns", p.pool.MaxConns()).
214230
Uint32("conncount", count).
@@ -220,18 +236,29 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
220236
if numToPrune > 1 {
221237
numToPrune >>= 1
222238
}
223-
if uint32(len(healthyConns[node])) < numToPrune {
224-
numToPrune = uint32(len(healthyConns[node]))
239+
240+
healthyNodeCount := mustEnsureUInt32(len(healthyConns[node]))
241+
if healthyNodeCount < numToPrune {
242+
numToPrune = healthyNodeCount
225243
}
226244
if numToPrune == 0 {
227245
continue
228246
}
229247

230248
for _, c := range healthyConns[node][:numToPrune] {
231-
log.Ctx(ctx).Debug().Str("pool", p.pool.ID()).Uint32("node", node).Msg("pruning connection")
249+
log.Ctx(ctx).Trace().Str("pool", p.pool.ID()).Uint32("node", node).Msg("pruning connection")
232250
p.pool.GC(c.Conn())
233251
}
234252

235-
log.Ctx(ctx).Info().Str("pool", p.pool.ID()).Uint32("node", node).Uint32("prunedCount", numToPrune).Msg("pruned connections")
253+
log.Ctx(ctx).Trace().Str("pool", p.pool.ID()).Uint32("node", node).Uint32("prunedCount", numToPrune).Msg("pruned connections")
254+
}
255+
}
256+
257+
// mustEnsureUInt32 converts the specified value to a uint32, panicking if it cannot be represented as one.
258+
func mustEnsureUInt32(value int) uint32 {
259+
ret, err := safecast.ToUint32(value)
260+
if err != nil {
261+
panic("specified value could not be cast to a uint32")
236262
}
263+
return ret
237264
}

pkg/balancer_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,19 @@ func TestNodeConnectionBalancerPrune(t *testing.T) {
150150

151151
p := newNodeConnectionBalancer[*FakePoolConn[*FakeConn], *FakeConn](pool, tracker, 1*time.Minute)
152152
p.seed = 0
153+
// nolint:gosec
154+
// G404 use of non cryptographically secure random number generator is not a concern here,
155+
// as it's used for deterministically shuffling the nodes in this test.
153156
p.rnd = rand.New(rand.NewSource(0))
154157

155158
for _, n := range tt.conns {
156159
pool.nodeForConn[NewFakeConn()] = n
157160
}
158161

159-
ctx, cancel := context.WithCancel(context.Background())
162+
ctx, cancel := context.WithCancel(t.Context())
160163
defer cancel()
161164

162-
p.pruneConnections(ctx)
165+
p.mustPruneConnections(ctx)
163166
require.Equal(t, len(tt.expectedGC), len(pool.gc))
164167
gcFromNodes := make([]uint32, 0, len(tt.expectedGC))
165168
for _, n := range pool.gc {

pkg/fake_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package pool
22

33
import (
44
"context"
5-
"hash/maphash"
6-
"strconv"
75
"sync"
6+
7+
"github.com/google/uuid"
88
)
99

1010
var (
@@ -38,13 +38,13 @@ type FakePool struct {
3838
sync.Mutex
3939
id string
4040
maxConns uint32
41-
gc map[*FakeConn]uint32
42-
nodeForConn map[*FakeConn]uint32
41+
gc map[*FakeConn]uint32 // GUARDED_BY(Mutex)
42+
nodeForConn map[*FakeConn]uint32 // GUARDED_BY(Mutex)
4343
}
4444

4545
func NewFakePool(maxConns uint32) *FakePool {
4646
return &FakePool{
47-
id: strconv.FormatUint(new(maphash.Hash).Sum64(), 16),
47+
id: uuid.NewString(),
4848
maxConns: maxConns,
4949
gc: make(map[*FakeConn]uint32, 0),
5050
nodeForConn: make(map[*FakeConn]uint32, 0),

pkg/health.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ func init() {
3333
type NodeHealthTracker struct {
3434
sync.RWMutex
3535
connConfig *pgx.ConnConfig
36-
healthyNodes map[uint32]struct{}
37-
nodesEverSeen map[uint32]*rate.Limiter
36+
healthyNodes map[uint32]struct{} // GUARDED_BY(RWMutex)
37+
nodesEverSeen map[uint32]*rate.Limiter // GUARDED_BY(RWMutex)
3838
newLimiter func() *rate.Limiter
3939
}
4040

@@ -58,6 +58,9 @@ func NewNodeHealthChecker(url string) (*NodeHealthTracker, error) {
5858
// Poll starts polling the cluster and recording the node IDs that it sees.
5959
func (t *NodeHealthTracker) Poll(ctx context.Context, interval time.Duration) {
6060
ticker := jitterbug.New(interval, jitterbug.Uniform{
61+
// nolint:gosec
62+
// G404 use of non cryptographically secure random number generator is not a concern here,
63+
// as it's used for jittering the interval for health checks.
6164
Source: rand.New(rand.NewSource(time.Now().Unix())),
6265
Min: interval,
6366
})

0 commit comments

Comments
 (0)