Skip to content

Commit a08c089

Browse files
author
iwysiu
committed
GODRIVER-1234 deadlock with minPoolSize
Change-Id: I478a49ae07c8106c88db520f42b97d27c1b2448d
1 parent bf97abe commit a08c089

File tree

5 files changed

+128
-29
lines changed

5 files changed

+128
-29
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,34 +45,23 @@ type connection struct {
4545
compressor wiremessage.CompressorID
4646
zliblevel int
4747
connected int32 // must be accessed using the sync/atomic package
48+
connectDone chan struct{}
49+
connectErr error
50+
config *connectionConfig
4851

4952
// pool related fields
5053
pool *pool
5154
poolID uint64
5255
generation uint64
5356
}
5457

55-
// newConnection handles the creation of a connection. It will dial, configure TLS, and perform
56-
// initialization handshakes.
58+
// newConnection handles the creation of a connection. It does not connect the connection.
5759
func newConnection(ctx context.Context, addr address.Address, opts ...ConnectionOption) (*connection, error) {
5860
cfg, err := newConnectionConfig(opts...)
5961
if err != nil {
6062
return nil, err
6163
}
6264

63-
nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String())
64-
if err != nil {
65-
return nil, ConnectionError{Wrapped: err, init: true}
66-
}
67-
68-
if cfg.tlsConfig != nil {
69-
tlsConfig := cfg.tlsConfig.Clone()
70-
nc, err = configureTLS(ctx, nc, addr, tlsConfig)
71-
if err != nil {
72-
return nil, ConnectionError{Wrapped: err, init: true}
73-
}
74-
}
75-
7665
var lifetimeDeadline time.Time
7766
if cfg.lifeTimeout > 0 {
7867
lifetimeDeadline = time.Now().Add(cfg.lifeTimeout)
@@ -82,32 +71,64 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
8271

8372
c := &connection{
8473
id: id,
85-
nc: nc,
8674
addr: addr,
8775
idleTimeout: cfg.idleTimeout,
8876
lifetimeDeadline: lifetimeDeadline,
8977
readTimeout: cfg.readTimeout,
9078
writeTimeout: cfg.writeTimeout,
79+
connectDone: make(chan struct{}),
80+
config: cfg,
81+
}
82+
atomic.StoreInt32(&c.connected, initialized)
83+
84+
return c, nil
85+
}
86+
87+
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
88+
// initialization handshakes.
89+
func (c *connection) connect(ctx context.Context) {
90+
if !atomic.CompareAndSwapInt32(&c.connected, initialized, connected) {
91+
return
92+
}
93+
defer close(c.connectDone)
94+
95+
var err error
96+
c.nc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
97+
if err != nil {
98+
atomic.StoreInt32(&c.connected, disconnected)
99+
c.connectErr = ConnectionError{Wrapped: err, init: true}
100+
return
101+
}
102+
103+
if c.config.tlsConfig != nil {
104+
tlsConfig := c.config.tlsConfig.Clone()
105+
c.nc, err = configureTLS(ctx, c.nc, c.addr, tlsConfig)
106+
if err != nil {
107+
atomic.StoreInt32(&c.connected, disconnected)
108+
c.connectErr = ConnectionError{Wrapped: err, init: true}
109+
return
110+
}
91111
}
92-
atomic.StoreInt32(&c.connected, connected)
93112

94113
c.bumpIdleDeadline()
95114

96115
// running isMaster and authentication is handled by a handshaker on the configuration instance.
97-
if cfg.handshaker != nil {
98-
c.desc, err = cfg.handshaker.Handshake(ctx, c.addr, initConnection{c})
116+
if c.config.handshaker != nil {
117+
c.desc, err = c.config.handshaker.Handshake(ctx, c.addr, initConnection{c})
99118
if err != nil {
100119
if c.nc != nil {
101120
_ = c.nc.Close()
102121
}
103-
return nil, ConnectionError{Wrapped: err, init: true}
122+
atomic.StoreInt32(&c.connected, disconnected)
123+
c.connectErr = ConnectionError{Wrapped: err, init: true}
124+
return
104125
}
105-
if cfg.descCallback != nil {
106-
cfg.descCallback(c.desc)
126+
if c.config.descCallback != nil {
127+
c.config.descCallback(c.desc)
107128
}
108129
if len(c.desc.Compression) > 0 {
109130
clientMethodLoop:
110-
for _, method := range cfg.compressors {
131+
for _, method := range c.config.compressors {
111132
for _, serverMethod := range c.desc.Compression {
112133
if method != serverMethod {
113134
continue
@@ -119,16 +140,20 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
119140
case "zlib":
120141
c.compressor = wiremessage.CompressorZLib
121142
c.zliblevel = wiremessage.DefaultZlibLevel
122-
if cfg.zlibLevel != nil {
123-
c.zliblevel = *cfg.zlibLevel
143+
if c.config.zlibLevel != nil {
144+
c.zliblevel = *c.config.zlibLevel
124145
}
125146
}
126147
break clientMethodLoop
127148
}
128149
}
129150
}
130151
}
131-
return c, nil
152+
}
153+
154+
func (c *connection) connectWait() error {
155+
<-c.connectDone
156+
return c.connectErr
132157
}
133158

134159
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,27 @@ func TestConnection(t *testing.T) {
4545
t.Errorf("errors do not match. got %v; want %v", got, want)
4646
}
4747
})
48+
})
49+
t.Run("connect", func(t *testing.T) {
4850
t.Run("dialer error", func(t *testing.T) {
4951
err := errors.New("dialer error")
5052
var want error = ConnectionError{Wrapped: err}
51-
_, got := newConnection(context.Background(), address.Address(""), WithDialer(func(Dialer) Dialer {
53+
conn, got := newConnection(context.Background(), address.Address(""), WithDialer(func(Dialer) Dialer {
5254
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, err })
5355
}))
56+
if got != nil {
57+
t.Errorf("newConnection shouldn't error. got %v; want nil", got)
58+
}
59+
conn.connect(context.Background())
60+
got = conn.connectWait()
5461
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
5562
t.Errorf("errors do not match. got %v; want %v", got, want)
5663
}
5764
})
5865
t.Run("handshaker error", func(t *testing.T) {
5966
err := errors.New("handshaker error")
6067
var want error = ConnectionError{Wrapped: err}
61-
_, got := newConnection(context.Background(), address.Address(""),
68+
conn, got := newConnection(context.Background(), address.Address(""),
6269
WithHandshaker(func(Handshaker) Handshaker {
6370
return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) {
6471
return description.Server{}, err
@@ -70,14 +77,19 @@ func TestConnection(t *testing.T) {
7077
})
7178
}),
7279
)
80+
if got != nil {
81+
t.Errorf("newConnection shouldn't error. got %v; want nil", got)
82+
}
83+
conn.connect(context.Background())
84+
got = conn.connectWait()
7385
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
7486
t.Errorf("errors do not match. got %v; want %v", got, want)
7587
}
7688
})
7789
t.Run("calls description callback", func(t *testing.T) {
7890
want := description.Server{Addr: address.Address("1.2.3.4:56789")}
7991
var got description.Server
80-
_, err := newConnection(context.Background(), address.Address(""),
92+
conn, err := newConnection(context.Background(), address.Address(""),
8193
withServerDescriptionCallback(func(desc description.Server) { got = desc },
8294
WithHandshaker(func(Handshaker) Handshaker {
8395
return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) {
@@ -92,6 +104,9 @@ func TestConnection(t *testing.T) {
92104
)...,
93105
)
94106
noerr(t, err)
107+
conn.connect(context.Background())
108+
err = conn.connectWait()
109+
noerr(t, err)
95110
if !cmp.Equal(got, want) {
96111
t.Errorf("Server descriptions do not match. got %v; want %v", got, want)
97112
}

x/mongo/driver/topology/pool.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ func (p *pool) connectionInitFunc() interface{} {
125125
if err != nil {
126126
return nil
127127
}
128+
129+
go c.connect(context.Background())
130+
128131
return c
129132
}
130133

@@ -321,6 +324,23 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
321324

322325
connVal := p.conns.Get()
323326
if c, ok := connVal.(*connection); ok && connVal != nil {
327+
// call connect if not connected
328+
if atomic.LoadInt32(&c.connected) == initialized {
329+
c.connect(ctx)
330+
}
331+
332+
err := c.connectWait()
333+
if err != nil {
334+
if p.monitor != nil {
335+
p.monitor.Event(&event.PoolEvent{
336+
Type: event.GetFailed,
337+
Address: p.address.String(),
338+
Reason: event.ReasonConnectionErrored,
339+
})
340+
}
341+
return nil, err
342+
}
343+
324344
if p.monitor != nil {
325345
p.monitor.Event(&event.PoolEvent{
326346
Type: event.GetSucceeded,
@@ -343,6 +363,21 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
343363
return nil, ctx.Err()
344364
default:
345365
c, reason, err := p.makeNewConnection(ctx)
366+
367+
if err != nil {
368+
if p.monitor != nil {
369+
p.monitor.Event(&event.PoolEvent{
370+
Type: event.GetFailed,
371+
Address: p.address.String(),
372+
Reason: reason,
373+
})
374+
}
375+
return nil, err
376+
}
377+
378+
c.connect(ctx)
379+
// wait for conn to be connected
380+
err = c.connectWait()
346381
if err != nil {
347382
if p.monitor != nil {
348383
p.monitor.Event(&event.PoolEvent{

x/mongo/driver/topology/server.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ const (
5858
disconnecting
5959
connected
6060
connecting
61+
initialized
6162
)
6263

6364
func connectionStateString(state int32) string {
@@ -70,6 +71,8 @@ func connectionStateString(state int32) string {
7071
return "Connected"
7172
case 3:
7273
return "Connecting"
74+
case 4:
75+
return "Initialized"
7376
}
7477

7578
return ""
@@ -493,6 +496,10 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
493496
}))
494497

495498
conn, err = newConnection(ctx, s.address, opts...)
499+
500+
conn.connect(ctx)
501+
502+
err := conn.connectWait()
496503
if err == nil {
497504
descPtr = &conn.desc
498505
}

x/mongo/driver/topology/topology_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
"go.mongodb.org/mongo-driver/x/mongo/driver"
1717
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
18+
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
1819
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
1920
)
2021

@@ -438,3 +439,19 @@ func TestSessionTimeout(t *testing.T) {
438439
}
439440
})
440441
}
442+
443+
func TestMinPoolSize(t *testing.T) {
444+
connStr := connstring.ConnString{
445+
Hosts: []string{"localhost:27017"},
446+
MinPoolSize: 10,
447+
MinPoolSizeSet: true,
448+
}
449+
topo, err := New(WithConnString(func(connstring.ConnString) connstring.ConnString { return connStr }))
450+
if err != nil {
451+
t.Errorf("topology.New shouldn't error. got: %v", err)
452+
}
453+
err = topo.Connect()
454+
if err != nil {
455+
t.Errorf("topology.Connect shouldn't error. got: %v", err)
456+
}
457+
}

0 commit comments

Comments
 (0)