Skip to content

Commit 9bf2a79

Browse files
author
iwysiu
authored
GODRIVER-1411 fix data race between connection.connect and pool.disconnect (#284)
1 parent c36aa54 commit 9bf2a79

File tree

3 files changed

+85
-27
lines changed

3 files changed

+85
-27
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,24 @@ var globalConnectionID uint64 = 1
3030
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
3131

3232
type connection struct {
33-
id string
34-
nc net.Conn // When nil, the connection is closed.
35-
addr address.Address
36-
idleTimeout time.Duration
37-
idleDeadline atomic.Value // Stores a time.Time
38-
lifetimeDeadline time.Time
39-
readTimeout time.Duration
40-
writeTimeout time.Duration
41-
desc description.Server
42-
compressor wiremessage.CompressorID
43-
zliblevel int
44-
zstdLevel int
45-
connected int32 // must be accessed using the sync/atomic package
46-
connectDone chan struct{}
47-
connectErr error
48-
config *connectionConfig
33+
id string
34+
nc net.Conn // When nil, the connection is closed.
35+
addr address.Address
36+
idleTimeout time.Duration
37+
idleDeadline atomic.Value // Stores a time.Time
38+
lifetimeDeadline time.Time
39+
readTimeout time.Duration
40+
writeTimeout time.Duration
41+
desc description.Server
42+
compressor wiremessage.CompressorID
43+
zliblevel int
44+
zstdLevel int
45+
connected int32 // must be accessed using the sync/atomic package
46+
connectDone chan struct{}
47+
connectErr error
48+
config *connectionConfig
49+
cancelConnectContext context.CancelFunc
50+
connectContextMade chan struct{}
4951

5052
// pool related fields
5153
pool *pool
@@ -68,14 +70,15 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
6870
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
6971

7072
c := &connection{
71-
id: id,
72-
addr: addr,
73-
idleTimeout: cfg.idleTimeout,
74-
lifetimeDeadline: lifetimeDeadline,
75-
readTimeout: cfg.readTimeout,
76-
writeTimeout: cfg.writeTimeout,
77-
connectDone: make(chan struct{}),
78-
config: cfg,
73+
id: id,
74+
addr: addr,
75+
idleTimeout: cfg.idleTimeout,
76+
lifetimeDeadline: lifetimeDeadline,
77+
readTimeout: cfg.readTimeout,
78+
writeTimeout: cfg.writeTimeout,
79+
connectDone: make(chan struct{}),
80+
config: cfg,
81+
connectContextMade: make(chan struct{}),
7982
}
8083
atomic.StoreInt32(&c.connected, initialized)
8184

@@ -85,12 +88,14 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
8588
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
8689
// initialization handshakes.
8790
func (c *connection) connect(ctx context.Context) {
88-
8991
if !atomic.CompareAndSwapInt32(&c.connected, initialized, connected) {
9092
return
9193
}
9294
defer close(c.connectDone)
9395

96+
ctx, c.cancelConnectContext = context.WithCancel(ctx)
97+
close(c.connectContextMade)
98+
9499
var err error
95100
c.nc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
96101
if err != nil {
@@ -178,6 +183,11 @@ func (c *connection) wait() error {
178183
return c.connectErr
179184
}
180185

186+
// closeConnectContext cancels the context that connect derives for dialing.
//
// It first blocks on connectContextMade, which connect closes immediately
// after assigning cancelConnectContext, so the cancel function is never read
// before it has been written. As far as this diff shows, connect is the only
// place that closes connectContextMade, so calling this on a connection whose
// connect never got past its state check would block indefinitely; the caller
// in pool.closeConnection guards the call with a check that the connection
// state is `connected`.
func (c *connection) closeConnectContext() {
187+
// Wait until connect has published cancelConnectContext.
<-c.connectContextMade
188+
c.cancelConnectContext()
189+
}
190+
181191
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
182192
var err error
183193
if atomic.LoadInt32(&c.connected) != connected {

x/mongo/driver/topology/pool.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ func connectionCloseFunc(v interface{}) {
117117
}
118118

119119
go func() {
120-
// wait for connection to finish trying to connect
121-
_ = c.wait()
122120
_ = c.pool.closeConnection(c)
123121
}()
124122
}
@@ -414,6 +412,11 @@ func (p *pool) closeConnection(c *connection) error {
414412
delete(p.opened, c.poolID)
415413
p.Unlock()
416414

415+
if atomic.LoadInt32(&c.connected) == connected {
416+
c.closeConnectContext()
417+
_ = c.wait() // Make sure that the connection has finished connecting
418+
}
419+
417420
if !atomic.CompareAndSwapInt32(&c.connected, connected, disconnected) {
418421
return nil // We're closing an already closed connection
419422
}

x/mongo/driver/topology/pool_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,51 @@ func TestPool(t *testing.T) {
175175
t.Errorf("Should have set the connection state on return. got %d; want %d", state, disconnected)
176176
}
177177
})
178+
t.Run("no race if connections are also connecting", func(t *testing.T) {
179+
cleanup := make(chan struct{})
180+
addr := bootstrapConnections(t, 3, func(nc net.Conn) {
181+
<-cleanup
182+
_ = nc.Close()
183+
})
184+
d := newdialer(&net.Dialer{})
185+
pc := poolConfig{
186+
Address: address.Address(addr.String()),
187+
}
188+
p, err := newPool(pc, WithDialer(func(Dialer) Dialer { return d }))
189+
noerr(t, err)
190+
err = p.connect()
191+
noerr(t, err)
192+
getDone := make(chan struct{})
193+
disconnectDone := make(chan struct{})
194+
_, err = p.get(context.Background())
195+
noerr(t, err)
196+
getCtx, getCancel := context.WithCancel(context.Background())
197+
defer getCancel()
198+
go func() {
199+
defer close(getDone)
200+
for {
201+
select {
202+
case <-disconnectDone:
203+
return
204+
default:
205+
_, _ = p.get(getCtx)
206+
noerr(t, err)
207+
time.Sleep(time.Microsecond)
208+
}
209+
}
210+
}()
211+
go func() {
212+
_, err := p.get(getCtx)
213+
noerr(t, err)
214+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Microsecond)
215+
defer cancel()
216+
err = p.disconnect(ctx)
217+
noerr(t, err)
218+
close(disconnectDone)
219+
}()
220+
<-getDone
221+
close(cleanup)
222+
})
178223
})
179224
t.Run("connect", func(t *testing.T) {
180225
t.Run("can reconnect a disconnected pool", func(t *testing.T) {

0 commit comments

Comments
 (0)