Skip to content

Commit a4ab0b0

Browse files
Fix race conditions in the connection pool code.
Signed-off-by: Arthur Schreiber <[email protected]>
1 parent 5e31484 commit a4ab0b0

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

go/pools/smartconnpool/pool.go

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ type ConnPool[C Connection] struct {
129129
// workers is a waitgroup for all the currently running worker goroutines
130130
workers sync.WaitGroup
131131
close chan struct{}
132-
capacityMu sync.Mutex
132+
capacityMu sync.RWMutex
133133

134134
config struct {
135135
// connect is the callback to create a new connection for the pool
@@ -430,18 +430,40 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool {
430430
if pool.wait.tryReturnConn(conn) {
431431
return true
432432
}
433-
if pool.closeOnIdleLimitReached(conn) {
433+
434+
for {
435+
if pool.capacity.Load() < pool.active.Load() {
436+
conn.Close()
437+
pool.closedConn()
438+
return true
439+
}
440+
441+
if pool.closeOnIdleLimitReached(conn) {
442+
return false
443+
}
444+
445+
if !pool.capacityMu.TryRLock() {
446+
// If we can't get a read lock here, it means that the pool is being closed. Retry and check `capacity` again.
447+
continue
448+
}
449+
defer pool.capacityMu.RUnlock()
450+
451+
if pool.capacity.Load() < pool.active.Load() {
452+
conn.Close()
453+
pool.closedConn()
454+
return true
455+
}
456+
457+
connSetting := conn.Conn.Setting()
458+
if connSetting == nil {
459+
pool.clean.Push(conn)
460+
} else {
461+
stack := connSetting.bucket & stackMask
462+
pool.settings[stack].Push(conn)
463+
pool.freshSettingsStack.Store(int64(stack))
464+
}
434465
return false
435466
}
436-
connSetting := conn.Conn.Setting()
437-
if connSetting == nil {
438-
pool.clean.Push(conn)
439-
} else {
440-
stack := connSetting.bucket & stackMask
441-
pool.settings[stack].Push(conn)
442-
pool.freshSettingsStack.Store(int64(stack))
443-
}
444-
return false
445467
}
446468

447469
func (pool *ConnPool[C]) pop(stack *connStack[C]) *Pooled[C] {
@@ -595,7 +617,9 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {
595617
// to other clients, wait until one of the connections is returned
596618
if conn == nil {
597619
start := time.Now()
598-
conn, err = pool.wait.waitForConn(ctx, nil)
620+
conn, err = pool.wait.waitForConn(ctx, nil, func() bool {
621+
return pool.close == nil || pool.capacity.Load() == 0
622+
})
599623
if err != nil {
600624
return nil, ErrTimeout
601625
}
@@ -652,7 +676,9 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (
652676
// wait for one of them
653677
if conn == nil {
654678
start := time.Now()
655-
conn, err = pool.wait.waitForConn(ctx, setting)
679+
conn, err = pool.wait.waitForConn(ctx, setting, func() bool {
680+
return pool.close == nil || pool.capacity.Load() == 0
681+
})
656682
if err != nil {
657683
return nil, ErrTimeout
658684
}

go/pools/smartconnpool/waitlist.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,17 @@ type waitlist[C Connection] struct {
5050
// The returned connection may _not_ have the requested Setting. This function can
5151
// also return a `nil` connection even if our context has expired, if the pool has
5252
// forced an expiration of all waiters in the waitlist.
53-
func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting) (*Pooled[C], error) {
53+
func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting, isClosed func() bool) (*Pooled[C], error) {
5454
elem := wl.nodes.Get().(*list.Element[waiter[C]])
5555
elem.Value = waiter[C]{setting: setting, conn: nil, ctx: ctx}
5656

5757
wl.mu.Lock()
58+
if isClosed() {
59+
// if the pool is closed, we can't wait for a connection, so return an error
60+
wl.nodes.Put(elem)
61+
wl.mu.Unlock()
62+
return nil, ErrConnPoolClosed
63+
}
5864
// add ourselves as a waiter at the end of the waitlist
5965
wl.list.PushBackValue(elem)
6066
wl.mu.Unlock()

go/pools/smartconnpool/waitlist_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ func TestWaitlistExpireWithMultipleWaiters(t *testing.T) {
3838

3939
for i := 0; i < waiterCount; i++ {
4040
go func() {
41-
_, err := wait.waitForConn(ctx, nil)
41+
_, err := wait.waitForConn(ctx, nil, func() bool {
42+
// This function is called to check if the pool is closed.
43+
return ctx.Err() != nil
44+
})
45+
4246
if err != nil {
4347
expireCount.Add(1)
4448
}

0 commit comments

Comments
 (0)