Skip to content

Commit b89d991

Browse files
author
iwysiu
committed
GODRIVER-1298 fix panic from closing disconnected connection
Change-Id: I3ae9b0aca1c40c4a1ee3a672b249fabc4ec8b3ad
1 parent f5a2d84 commit b89d991

File tree

5 files changed

+70
-18
lines changed

5 files changed

+70
-18
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type connection struct {
3737
nc net.Conn // When nil, the connection is closed.
3838
addr address.Address
3939
idleTimeout time.Duration
40-
idleDeadline time.Time
40+
idleDeadline atomic.Value // Stores a time.Time
4141
lifetimeDeadline time.Time
4242
readTimeout time.Duration
4343
writeTimeout time.Duration
@@ -87,6 +87,7 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
8787
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
8888
// initialization handshakes.
8989
func (c *connection) connect(ctx context.Context) {
90+
9091
if !atomic.CompareAndSwapInt32(&c.connected, initialized, connected) {
9192
return
9293
}
@@ -151,8 +152,10 @@ func (c *connection) connect(ctx context.Context) {
151152
}
152153
}
153154

154-
func (c *connection) connectWait() error {
155-
<-c.connectDone
155+
func (c *connection) wait() error {
156+
if c.connectDone != nil {
157+
<-c.connectDone
158+
}
156159
return c.connectErr
157160
}
158161

@@ -259,7 +262,11 @@ func (c *connection) close() error {
259262
return nil
260263
}
261264
if c.pool == nil {
262-
err := c.nc.Close()
265+
var err error
266+
267+
if c.nc != nil {
268+
err = c.nc.Close()
269+
}
263270
atomic.StoreInt32(&c.connected, disconnected)
264271
return err
265272
}
@@ -268,7 +275,8 @@ func (c *connection) close() error {
268275

269276
func (c *connection) expired() bool {
270277
now := time.Now()
271-
if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) {
278+
idleDeadline, ok := c.idleDeadline.Load().(time.Time)
279+
if ok && now.After(idleDeadline) {
272280
return true
273281
}
274282

@@ -281,7 +289,7 @@ func (c *connection) expired() bool {
281289

282290
func (c *connection) bumpIdleDeadline() {
283291
if c.idleTimeout > 0 {
284-
c.idleDeadline = time.Now().Add(c.idleTimeout)
292+
c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
285293
}
286294
}
287295

x/mongo/driver/topology/connection_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestConnection(t *testing.T) {
5757
t.Errorf("newConnection shouldn't error. got %v; want nil", got)
5858
}
5959
conn.connect(context.Background())
60-
got = conn.connectWait()
60+
got = conn.wait()
6161
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
6262
t.Errorf("errors do not match. got %v; want %v", got, want)
6363
}
@@ -81,7 +81,7 @@ func TestConnection(t *testing.T) {
8181
t.Errorf("newConnection shouldn't error. got %v; want nil", got)
8282
}
8383
conn.connect(context.Background())
84-
got = conn.connectWait()
84+
got = conn.wait()
8585
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
8686
t.Errorf("errors do not match. got %v; want %v", got, want)
8787
}
@@ -105,7 +105,7 @@ func TestConnection(t *testing.T) {
105105
)
106106
noerr(t, err)
107107
conn.connect(context.Background())
108-
err = conn.connectWait()
108+
err = conn.wait()
109109
noerr(t, err)
110110
if !cmp.Equal(got, want) {
111111
t.Errorf("Server descriptions do not match. got %v; want %v", got, want)

x/mongo/driver/topology/pool.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ func connectionCloseFunc(v interface{}) {
116116
return
117117
}
118118

119-
go func() { _ = c.pool.closeConnection(c) }()
119+
go func() {
120+
// wait for connection to finish trying to connect
121+
_ = c.wait()
122+
_ = c.pool.closeConnection(c)
123+
}()
120124
}
121125

122126
// connectionInitFunc returns an init function for the resource pool that will make new connections for this pool
@@ -329,7 +333,7 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
329333
c.connect(ctx)
330334
}
331335

332-
err := c.connectWait()
336+
err := c.wait()
333337
if err != nil {
334338
if p.monitor != nil {
335339
p.monitor.Event(&event.PoolEvent{
@@ -377,7 +381,7 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
377381

378382
c.connect(ctx)
379383
// wait for conn to be connected
380-
err = c.connectWait()
384+
err = c.wait()
381385
if err != nil {
382386
if p.monitor != nil {
383387
p.monitor.Event(&event.PoolEvent{
@@ -413,10 +417,14 @@ func (p *pool) closeConnection(c *connection) error {
413417
if !atomic.CompareAndSwapInt32(&c.connected, connected, disconnected) {
414418
return nil // We're closing an already closed connection
415419
}
416-
err := c.nc.Close()
417-
if err != nil {
418-
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to closeConnection net.Conn"}
420+
421+
if c.nc != nil {
422+
err := c.nc.Close()
423+
if err != nil {
424+
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to close net.Conn"}
425+
}
419426
}
427+
420428
return nil
421429
}
422430

x/mongo/driver/topology/pool_test.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ package topology
33
import (
44
"context"
55
"errors"
6-
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
76
"net"
8-
97
"sync/atomic"
108
"testing"
119
"time"
10+
11+
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
1212
)
1313

1414
func TestPool(t *testing.T) {
@@ -531,5 +531,41 @@ func TestPool(t *testing.T) {
531531
t.Errorf("Should not return connection to pool twice. got %d; want %d", p.conns.size, 1)
532532
}
533533
})
534+
t.Run("close does not panic if expires before connected", func(t *testing.T) {
535+
cleanup := make(chan struct{})
536+
defer close(cleanup)
537+
addr := bootstrapConnections(t, 3, func(nc net.Conn) {
538+
<-cleanup
539+
_ = nc.Close()
540+
})
541+
d := newSleepDialer(&net.Dialer{})
542+
pc := poolConfig{
543+
Address: address.Address(addr.String()),
544+
MinPoolSize: 1,
545+
}
546+
maintainInterval = time.Second
547+
p, err := newPool(pc, WithDialer(func(Dialer) Dialer { return d }),
548+
WithLifeTimeout(func(time.Duration) time.Duration { return 10 * time.Millisecond }),
549+
)
550+
maintainInterval = time.Minute
551+
noerr(t, err)
552+
err = p.connect()
553+
noerr(t, err)
554+
_, err = p.get(context.Background())
555+
noerr(t, err)
556+
})
534557
})
535558
}
559+
560+
type sleepDialer struct {
561+
Dialer
562+
}
563+
564+
func newSleepDialer(d Dialer) *sleepDialer {
565+
return &sleepDialer{d}
566+
}
567+
568+
func (d *sleepDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
569+
time.Sleep(5 * time.Second)
570+
return d.Dialer.DialContext(ctx, network, address)
571+
}

x/mongo/driver/topology/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
499499

500500
conn.connect(ctx)
501501

502-
err := conn.connectWait()
502+
err := conn.wait()
503503
if err == nil {
504504
descPtr = &conn.desc
505505
}

0 commit comments

Comments
 (0)