@@ -3,15 +3,17 @@ package pool
33import (
44 "context"
55 "hash/maphash"
6+ "maps"
7+ "math"
68 "math/rand"
9+ "slices"
710 "strconv"
811 "time"
912
13+ "github.com/ccoveille/go-safecast"
1014 "github.com/jackc/pgx/v5"
1115 "github.com/jackc/pgx/v5/pgxpool"
1216 "github.com/prometheus/client_golang/prometheus"
13- "golang.org/x/exp/maps"
14- "golang.org/x/exp/slices"
1517 "golang.org/x/sync/semaphore"
1618
1719 log "github.com/rs/zerolog"
@@ -83,14 +85,26 @@ type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct {
8385// newNodeConnectionBalancer is generic over underlying connection types for
8486// testing purposes. Callers should use the exported NewNodeConnectionBalancer.
8587func newNodeConnectionBalancer [P balancePoolConn [C ], C balanceConn ](pool balanceablePool [P , C ], healthTracker * NodeHealthTracker , interval time.Duration ) * nodeConnectionBalancer [P , C ] {
86- seed := int64 (new (maphash.Hash ).Sum64 ())
88+ seed := int64 (0 )
89+ for seed == 0 {
90+ // Sum64 returns a uint64; safecast.ToInt64 fails (yielding 0) whenever the
91+ // value exceeds math.MaxInt64, which is the case for about half of all
92+ // uint64 values, so we retry until we get a seed that fits.
93+ // Note the subtraction wraps modulo 2^64, so it does not widen the range:
94+ // the resulting seed is always a non-negative int64 (and the loop excludes 0).
95+ seed , _ = safecast .ToInt64 (new (maphash.Hash ).Sum64 () - math .MaxInt64 )
96+ }
8797 return & nodeConnectionBalancer [P , C ]{
8898 ticker : time .NewTicker (interval ),
8999 sem : semaphore .NewWeighted (1 ),
90100 healthTracker : healthTracker ,
91101 pool : pool ,
92102 seed : seed ,
93- rnd : rand .New (rand .NewSource (seed )),
103+ // nolint:gosec
104+ // use of a non-cryptographically-secure random number generator is not a concern here,
105+ // as it is only used to shuffle the nodes to balance the connections when the number of
106+ // connections does not divide evenly.
107+ rnd : rand .New (rand .NewSource (seed )),
94108 }
95109}
96110
@@ -104,18 +118,18 @@ func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) {
104118 case <- p .ticker .C :
105119 if p .sem .TryAcquire (1 ) {
106120 ctx , cancel := context .WithTimeout (ctx , 10 * time .Second )
107- p .pruneConnections (ctx )
121+ p .mustPruneConnections (ctx )
108122 cancel ()
109123 p .sem .Release (1 )
110124 }
111125 }
112126 }
113127}
114128
115- // pruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes)
129+ // mustPruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes)
116130// This causes the pool to reconnect, which over time will lead to a balanced number of connections
117131// across each node.
118- func (p * nodeConnectionBalancer [P , C ]) pruneConnections (ctx context.Context ) {
132+ func (p * nodeConnectionBalancer [P , C ]) mustPruneConnections (ctx context.Context ) {
119133 start := time .Now ()
120134 defer func () {
121135 pruningTimeHistogram .WithLabelValues (p .pool .ID ()).Observe (float64 (time .Since (start ).Milliseconds ()))
@@ -142,7 +156,9 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
142156 }
143157 }
144158
145- nodeCount := uint32 (p .healthTracker .HealthyNodeCount ())
159+ // It's highly unlikely that we'll ever have an overflow in
160+ // this context, so we discard the conversion error.
161+ nodeCount , _ := safecast .ToUint32 (p .healthTracker .HealthyNodeCount ())
146162 if nodeCount == 0 {
147163 nodeCount = 1
148164 }
@@ -169,7 +185,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
169185 }
170186 p .healthTracker .RUnlock ()
171187
172- nodes := maps .Keys (connectionCounts )
188+ nodes := slices . Collect ( maps .Keys (connectionCounts ) )
173189 slices .Sort (nodes )
174190
175191 // Shuffle nodes in place deterministically based on the initial seed.
@@ -198,7 +214,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
198214 // it's possible for the difference in connections between nodes to differ by up to
199215 // the number of nodes.
200216 if p .healthTracker .HealthyNodeCount () == 0 ||
201- uint32 ( i ) < p .pool .MaxConns ()% uint32 ( p .healthTracker .HealthyNodeCount () ) {
217+ i < int ( p .pool .MaxConns ()) % p .healthTracker .HealthyNodeCount () {
202218 perNodeMax ++
203219 }
204220
@@ -208,7 +224,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
208224 if count <= perNodeMax {
209225 continue
210226 }
211- log .Ctx (ctx ).Info ().
227+ log .Ctx (ctx ).Trace ().
212228 Uint32 ("node" , node ).
213229 Uint32 ("poolmaxconns" , p .pool .MaxConns ()).
214230 Uint32 ("conncount" , count ).
@@ -220,18 +236,29 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
220236 if numToPrune > 1 {
221237 numToPrune >>= 1
222238 }
223- if uint32 (len (healthyConns [node ])) < numToPrune {
224- numToPrune = uint32 (len (healthyConns [node ]))
239+
240+ healthyNodeCount := mustEnsureUInt32 (len (healthyConns [node ]))
241+ if healthyNodeCount < numToPrune {
242+ numToPrune = healthyNodeCount
225243 }
226244 if numToPrune == 0 {
227245 continue
228246 }
229247
230248 for _ , c := range healthyConns [node ][:numToPrune ] {
231- log .Ctx (ctx ).Debug ().Str ("pool" , p .pool .ID ()).Uint32 ("node" , node ).Msg ("pruning connection" )
249+ log .Ctx (ctx ).Trace ().Str ("pool" , p .pool .ID ()).Uint32 ("node" , node ).Msg ("pruning connection" )
232250 p .pool .GC (c .Conn ())
233251 }
234252
235- log .Ctx (ctx ).Info ().Str ("pool" , p .pool .ID ()).Uint32 ("node" , node ).Uint32 ("prunedCount" , numToPrune ).Msg ("pruned connections" )
253+ log .Ctx (ctx ).Trace ().Str ("pool" , p .pool .ID ()).Uint32 ("node" , node ).Uint32 ("prunedCount" , numToPrune ).Msg ("pruned connections" )
254+ }
255+ }
256+
257+ // mustEnsureUInt32 ensures that the specified value can be represented as a uint32, panicking if it cannot.
258+ func mustEnsureUInt32 (value int ) uint32 {
259+ ret , err := safecast .ToUint32 (value )
260+ if err != nil {
261+ panic ("specified value could not be cast to a uint32" )
236262 }
263+ return ret
237264}
0 commit comments