Skip to content

Commit 98f5b78

Browse files
committed
Fix race conditions and leaks in connection pool
GODRIVER-388 GODRIVER-389 GODRIVER-390 Change-Id: I1390c184beb379d4cc775ba07b5d222342dc957e
1 parent 0dfcbab commit 98f5b78

File tree

2 files changed

+147
-9
lines changed

2 files changed

+147
-9
lines changed

core/connection/pool.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ loop:
123123
select {
124124
case pc := <-p.conns:
125125
// This error would be overwritten by the semaphore
126-
_ = pc.Close()
126+
_ = p.closeConnection(pc)
127127
default:
128128
break loop
129129
}
@@ -155,24 +155,30 @@ func (p *pool) Get(ctx context.Context) (Connection, *description.Server, error)
155155
return nil, nil, ErrPoolClosed
156156
}
157157

158+
err := p.sem.Acquire(ctx, 1)
159+
if err != nil {
160+
return nil, nil, err
161+
}
162+
return p.get(ctx)
163+
}
164+
165+
func (p *pool) get(ctx context.Context) (Connection, *description.Server, error) {
158166
g := atomic.LoadUint64(&p.generation)
159167
select {
160168
case c := <-p.conns:
161169
if c.Expired() {
162170
go p.closeConnection(c)
163-
return p.Get(ctx)
171+
return p.get(ctx)
164172
}
165173

166-
return &acquired{Connection: c}, nil, nil
174+
return &acquired{Connection: c, sem: p.sem}, nil, nil
167175
case <-ctx.Done():
176+
p.sem.Release(1)
168177
return nil, nil, ctx.Err()
169178
default:
170-
err := p.sem.Acquire(ctx, 1)
171-
if err != nil {
172-
return nil, nil, err
173-
}
174179
c, desc, err := New(ctx, p.address, p.opts...)
175180
if err != nil {
181+
p.sem.Release(1)
176182
return nil, nil, err
177183
}
178184

@@ -185,20 +191,20 @@ func (p *pool) Get(ctx context.Context) (Connection, *description.Server, error)
185191
p.Lock()
186192
if atomic.LoadInt32(&p.connected) != connected {
187193
p.Unlock()
194+
p.sem.Release(1)
188195
p.closeConnection(pc)
189196
return nil, nil, ErrPoolClosed
190197
}
191198
defer p.Unlock()
192199
p.inflight[pc.id] = pc
193-
return &acquired{Connection: pc}, desc, nil
200+
return &acquired{Connection: pc, sem: p.sem}, desc, nil
194201
}
195202
}
196203

197204
func (p *pool) closeConnection(pc *pooledConnection) error {
198205
if !atomic.CompareAndSwapInt32(&pc.closed, 0, 1) {
199206
return nil
200207
}
201-
pc.p.sem.Release(1)
202208
p.Lock()
203209
delete(p.inflight, pc.id)
204210
p.Unlock()
@@ -241,6 +247,7 @@ func (pc *pooledConnection) Expired() bool {
241247
type acquired struct {
242248
Connection
243249

250+
sem *semaphore.Weighted
244251
sync.Mutex
245252
}
246253

@@ -269,6 +276,7 @@ func (a *acquired) Close() error {
269276
return nil
270277
}
271278
err := a.Connection.Close()
279+
a.sem.Release(1)
272280
a.Connection = nil
273281
return err
274282
}

core/connection/pool_test.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"errors"
66
"net"
7+
"runtime"
8+
"sync"
79
"sync/atomic"
810
"testing"
911
"time"
@@ -119,6 +121,8 @@ func TestPool(t *testing.T) {
119121
t.Errorf("Should have closed 3 connections, but didn't. got %d; want %d", d.lenclosed(), 3)
120122
}
121123
close(cleanup)
124+
err = conns[2].Close()
125+
noerr(t, err)
122126
ok := p.(*pool).sem.TryAcquire(int64(p.(*pool).capacity))
123127
if !ok {
124128
t.Errorf("clean shutdown should acquire and release semaphore, but semaphore still held")
@@ -478,6 +482,9 @@ func TestPool(t *testing.T) {
478482
err = p.Drain()
479483
noerr(t, err)
480484

485+
err = conns[1].Close()
486+
noerr(t, err)
487+
481488
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
482489
defer cancel()
483490
c, _, err = p.Get(ctx)
@@ -487,6 +494,129 @@ func TestPool(t *testing.T) {
487494
}
488495
close(cleanup)
489496
})
497+
t.Run("Cannot starve connection request", func(t *testing.T) {
498+
cleanup := make(chan struct{})
499+
address := bootstrapConnections(t, 3, func(nc net.Conn) {
500+
<-cleanup
501+
nc.Close()
502+
})
503+
d := newdialer(&net.Dialer{})
504+
p, err := NewPool(addr.Addr(address.String()), 1, 1, WithDialer(func(Dialer) Dialer { return d }))
505+
noerr(t, err)
506+
err = p.Connect(context.Background())
507+
noerr(t, err)
508+
conn, _, err := p.Get(context.Background())
509+
if d.lenopened() != 1 {
510+
t.Errorf("Should have opened 1 connections, but didn't. got %d; want %d", d.lenopened(), 1)
511+
}
512+
513+
var wg sync.WaitGroup
514+
515+
wg.Add(1)
516+
ch := make(chan struct{})
517+
go func() {
518+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
519+
defer cancel()
520+
ch <- struct{}{}
521+
_, _, err := p.Get(ctx)
522+
if err != nil {
523+
t.Errorf("Should not be able to starve connection request, but got error: %v", err)
524+
}
525+
wg.Done()
526+
}()
527+
<-ch
528+
runtime.Gosched()
529+
err = conn.Close()
530+
noerr(t, err)
531+
wg.Wait()
532+
close(cleanup)
533+
})
534+
t.Run("Does not leak permit from failure to dial connection", func(t *testing.T) {
535+
cleanup := make(chan struct{})
536+
address := bootstrapConnections(t, 0, func(nc net.Conn) {
537+
<-cleanup
538+
nc.Close()
539+
})
540+
close(cleanup)
541+
want := errors.New("dialing error")
542+
p, err := NewPool(
543+
addr.Addr(address.String()), 1, 2,
544+
WithDialer(
545+
func(Dialer) Dialer {
546+
return DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
547+
return nil, want
548+
})
549+
}),
550+
)
551+
noerr(t, err)
552+
err = p.Connect(context.Background())
553+
noerr(t, err)
554+
_, _, err = p.Get(context.Background())
555+
if err != want {
556+
t.Errorf("Expected dial failure but got: %v", err)
557+
}
558+
ok := p.(*pool).sem.TryAcquire(int64(p.(*pool).capacity))
559+
if !ok {
560+
t.Errorf("Dial failure should not leak semaphore permit")
561+
} else {
562+
p.(*pool).sem.Release(int64(p.(*pool).capacity))
563+
}
564+
})
565+
t.Run("Does not leak permit from cancelled context", func(t *testing.T) {
566+
cleanup := make(chan struct{})
567+
address := bootstrapConnections(t, 1, func(nc net.Conn) {
568+
<-cleanup
569+
nc.Close()
570+
})
571+
close(cleanup)
572+
d := newdialer(&net.Dialer{})
573+
p, err := NewPool(addr.Addr(address.String()), 1, 2, WithDialer(func(Dialer) Dialer { return d }))
574+
noerr(t, err)
575+
err = p.Connect(context.Background())
576+
noerr(t, err)
577+
ctx, cancel := context.WithCancel(context.Background())
578+
cancel()
579+
_, _, err = p.Get(ctx)
580+
if err != context.Canceled {
581+
t.Errorf("Expected context cancelled error. got %v; want %v", err, context.Canceled)
582+
}
583+
ok := p.(*pool).sem.TryAcquire(int64(p.(*pool).capacity))
584+
if !ok {
585+
t.Errorf("Canceled context should not leak semaphore permit")
586+
} else {
587+
p.(*pool).sem.Release(int64(p.(*pool).capacity))
588+
}
589+
})
590+
t.Run("Get does not acquire multiple permits", func(t *testing.T) {
591+
cleanup := make(chan struct{})
592+
address := bootstrapConnections(t, 2, func(nc net.Conn) {
593+
<-cleanup
594+
nc.Close()
595+
})
596+
close(cleanup)
597+
d := newdialer(&net.Dialer{})
598+
p, err := NewPool(addr.Addr(address.String()), 1, 2, WithDialer(func(Dialer) Dialer { return d }))
599+
noerr(t, err)
600+
err = p.Connect(context.Background())
601+
noerr(t, err)
602+
c, _, err := p.Get(context.Background())
603+
noerr(t, err)
604+
err = c.Close()
605+
noerr(t, err)
606+
607+
p.Drain()
608+
609+
c, _, err = p.Get(context.Background())
610+
noerr(t, err)
611+
err = c.Close()
612+
noerr(t, err)
613+
ok := p.(*pool).sem.TryAcquire(int64(p.(*pool).capacity))
614+
if !ok {
615+
t.Errorf("Get should not acquire multiple permits (when expired conn in idle pool)")
616+
} else {
617+
p.(*pool).sem.Release(int64(p.(*pool).capacity))
618+
}
619+
})
490620
})
491621
t.Run("Connection", func(t *testing.T) {
492622
t.Run("Connection Close Does Not Error After Pool Is Disconnected", func(t *testing.T) {

0 commit comments

Comments
 (0)