Skip to content

Commit ccdbaf2

Browse files
author
Divjot Arora
committed
GODRIVER-1609 Ensure a single ConnectionClosed event is emitted per connection (#438)
1 parent d3eb353 commit ccdbaf2

File tree

6 files changed

+357
-59
lines changed

6 files changed

+357
-59
lines changed

x/mongo/driver/topology/CMAP_spec_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ type testInfo struct {
6868

6969
const cmapTestDir = "../../../../data/connection-monitoring-and-pooling/"
7070

71-
func TestCMAP(t *testing.T) {
71+
func TestCMAPSpec(t *testing.T) {
7272
for _, testFileName := range testHelpers.FindJSONFilesInDir(t, cmapTestDir) {
7373
t.Run(testFileName, func(t *testing.T) {
7474
runCMAPTest(t, testFileName)
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package topology
8+
9+
import (
10+
"context"
11+
"errors"
12+
"net"
13+
"testing"
14+
"time"
15+
16+
"go.mongodb.org/mongo-driver/event"
17+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
18+
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
19+
)
20+
21+
func TestCMAPProse(t *testing.T) {
22+
t.Run("created and closed events", func(t *testing.T) {
23+
created := make(chan *event.PoolEvent, 10)
24+
closed := make(chan *event.PoolEvent, 10)
25+
clearEvents := func() {
26+
for len(created) > 0 {
27+
<-created
28+
}
29+
for len(closed) > 0 {
30+
<-closed
31+
}
32+
}
33+
monitor := &event.PoolMonitor{
34+
Event: func(evt *event.PoolEvent) {
35+
switch evt.Type {
36+
case event.ConnectionCreated:
37+
created <- evt
38+
case event.ConnectionClosed:
39+
closed <- evt
40+
}
41+
},
42+
}
43+
getConfig := func() poolConfig {
44+
return poolConfig{
45+
PoolMonitor: monitor,
46+
}
47+
}
48+
assertConnectionCounts := func(t *testing.T, p *pool, numCreated, numClosed int) {
49+
t.Helper()
50+
51+
assert.Equal(t, numCreated, len(created), "expected %d creation events, got %d", numCreated, len(created))
52+
assert.Equal(t, numClosed, len(closed), "expected %d closed events, got %d", numClosed, len(closed))
53+
54+
netCount := numCreated - numClosed
55+
assert.Equal(t, netCount, len(p.opened), "expected %d connections in opened map, got %d", netCount,
56+
len(p.opened))
57+
}
58+
59+
t.Run("get", func(t *testing.T) {
60+
t.Run("errored connection exists in pool", func(t *testing.T) {
61+
// If a connection is created as part of minPoolSize maintenance and errors while connecting, get()
62+
// should report that error and publish an event.
63+
clearEvents()
64+
65+
var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
66+
return &testNetConn{writeerr: errors.New("write error")}, nil
67+
}
68+
69+
cfg := getConfig()
70+
cfg.MinPoolSize = 1
71+
connOpts := []ConnectionOption{
72+
WithDialer(func(Dialer) Dialer { return dialer }),
73+
WithHandshaker(func(Handshaker) Handshaker {
74+
return operation.NewIsMaster()
75+
}),
76+
}
77+
pool := createTestPool(t, cfg, connOpts...)
78+
79+
_, err := pool.get(context.Background())
80+
assert.NotNil(t, err, "expected get() error, got nil")
81+
assertConnectionCounts(t, pool, 1, 1)
82+
})
83+
t.Run("pool is empty", func(t *testing.T) {
84+
// If a new connection is created during get(), get() should report that error and publish an event.
85+
clearEvents()
86+
87+
var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
88+
return &testNetConn{writeerr: errors.New("write error")}, nil
89+
}
90+
91+
connOpts := []ConnectionOption{
92+
WithDialer(func(Dialer) Dialer { return dialer }),
93+
WithHandshaker(func(Handshaker) Handshaker {
94+
return operation.NewIsMaster()
95+
}),
96+
}
97+
pool := createTestPool(t, getConfig(), connOpts...)
98+
99+
_, err := pool.get(context.Background())
100+
assert.NotNil(t, err, "expected get() error, got nil")
101+
assertConnectionCounts(t, pool, 1, 1)
102+
})
103+
})
104+
t.Run("put", func(t *testing.T) {
105+
t.Run("errored connection", func(t *testing.T) {
106+
// If the connection being returned to the pool encountered a network error, it should be removed from
107+
// the pool and an event should be published.
108+
clearEvents()
109+
110+
var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
111+
return &testNetConn{writeerr: errors.New("write error")}, nil
112+
}
113+
114+
// We don't use the WithHandshaker option so the connection won't error during handshaking.
115+
connOpts := []ConnectionOption{
116+
WithDialer(func(Dialer) Dialer { return dialer }),
117+
}
118+
pool := createTestPool(t, getConfig(), connOpts...)
119+
120+
conn, err := pool.get(context.Background())
121+
assert.Nil(t, err, "get error: %v", err)
122+
123+
// Force a network error by writing to the connection.
124+
err = conn.writeWireMessage(context.Background(), nil)
125+
assert.NotNil(t, err, "expected writeWireMessage error, got nil")
126+
127+
err = pool.put(conn)
128+
assert.Nil(t, err, "put error: %v", err)
129+
130+
assertConnectionCounts(t, pool, 1, 1)
131+
evt := <-closed
132+
assert.Equal(t, event.ReasonConnectionErrored, evt.Reason, "expected reason %q, got %q",
133+
event.ReasonConnectionErrored, evt.Reason)
134+
})
135+
t.Run("expired connection", func(t *testing.T) {
136+
// If the connection being returned to the pool is expired, it should be removed from the pool and an
137+
// event should be published.
138+
clearEvents()
139+
140+
var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
141+
return &testNetConn{}, nil
142+
}
143+
144+
// We don't use the WithHandshaker option so the connection won't error during handshaking.
145+
// WithIdleTimeout must be used because the connection.idleTimeoutExpired() function only checks the
146+
// deadline if the idleTimeout option is greater than 0.
147+
connOpts := []ConnectionOption{
148+
WithDialer(func(Dialer) Dialer { return dialer }),
149+
WithIdleTimeout(func(time.Duration) time.Duration { return 1 * time.Second }),
150+
}
151+
pool := createTestPool(t, getConfig(), connOpts...)
152+
153+
conn, err := pool.get(context.Background())
154+
assert.Nil(t, err, "get error: %v", err)
155+
156+
// Set the idleDeadline to a time in the past to simulate expiration.
157+
pastTime := time.Now().Add(-10 * time.Second)
158+
conn.idleDeadline.Store(pastTime)
159+
160+
err = pool.put(conn)
161+
assert.Nil(t, err, "put error: %v", err)
162+
163+
assertConnectionCounts(t, pool, 1, 1)
164+
evt := <-closed
165+
assert.Equal(t, event.ReasonIdle, evt.Reason, "expected reason %q, got %q",
166+
event.ReasonIdle, evt.Reason)
167+
})
168+
})
169+
t.Run("disconnect", func(t *testing.T) {
170+
t.Run("connections returned gracefully", func(t *testing.T) {
171+
// If all connections are in the pool when disconnect is called, they should be closed gracefully and
172+
// events should be published.
173+
clearEvents()
174+
175+
numConns := 5
176+
var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
177+
return &testNetConn{}, nil
178+
}
179+
pool := createTestPool(t, getConfig(), WithDialer(func(Dialer) Dialer { return dialer }))
180+
181+
conns := checkoutConnections(t, pool, numConns)
182+
assertConnectionCounts(t, pool, numConns, 0)
183+
184+
// Return all connections to the pool and assert that none were closed by put().
185+
for i, c := range conns {
186+
err := pool.put(c)
187+
assert.Nil(t, err, "put error at index %d: %v", i, err)
188+
}
189+
assertConnectionCounts(t, pool, numConns, 0)
190+
191+
// Disconnect the pool and assert that a closed event is published for each connection.
192+
err := pool.disconnect(context.Background())
193+
assert.Nil(t, err, "disconnect error: %v", err)
194+
assertConnectionCounts(t, pool, numConns, numConns)
195+
196+
for len(closed) > 0 {
197+
evt := <-closed
198+
assert.Equal(t, event.ReasonPoolClosed, evt.Reason, "expected reason %q, got %q",
199+
event.ReasonPoolClosed, evt.Reason)
200+
}
201+
})
202+
t.Run("connections closed forcefully", func(t *testing.T) {
203+
// If some connections are still checked out when disconnect is called, they should be closed
204+
// forcefully and events should be published for them.
205+
clearEvents()
206+
207+
numConns := 5
208+
var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
209+
return &testNetConn{}, nil
210+
}
211+
pool := createTestPool(t, getConfig(), WithDialer(func(Dialer) Dialer { return dialer }))
212+
213+
conns := checkoutConnections(t, pool, numConns)
214+
assertConnectionCounts(t, pool, numConns, 0)
215+
216+
// Only return 2 of the connection.
217+
for i := 0; i < 2; i++ {
218+
err := pool.put(conns[i])
219+
assert.Nil(t, err, "put error at index %d: %v", i, err)
220+
}
221+
conns = conns[2:]
222+
assertConnectionCounts(t, pool, numConns, 0)
223+
224+
// Disconnect and assert that events are published for all conections.
225+
err := pool.disconnect(context.Background())
226+
assert.Nil(t, err, "disconnect error: %v", err)
227+
assertConnectionCounts(t, pool, numConns, numConns)
228+
229+
// Return the remaining connections and assert that the closed event count does not increase because
230+
// these connections have already been closed.
231+
for i, c := range conns {
232+
err := pool.put(c)
233+
assert.Nil(t, err, "put error at index %d: %v", i, err)
234+
}
235+
assertConnectionCounts(t, pool, numConns, numConns)
236+
237+
// Ensure all closed events have the correct reason.
238+
for len(closed) > 0 {
239+
evt := <-closed
240+
assert.Equal(t, event.ReasonPoolClosed, evt.Reason, "expected reason %q, got %q",
241+
event.ReasonPoolClosed, evt.Reason)
242+
}
243+
244+
})
245+
})
246+
})
247+
}
248+
249+
func createTestPool(t *testing.T, cfg poolConfig, opts ...ConnectionOption) *pool {
250+
t.Helper()
251+
252+
pool, err := newPool(cfg, opts...)
253+
assert.Nil(t, err, "newPool error: %v", err)
254+
err = pool.connect()
255+
assert.Nil(t, err, "connect error: %v", err)
256+
257+
return pool
258+
}
259+
260+
func checkoutConnections(t *testing.T, p *pool, numConns int) []*connection {
261+
conns := make([]*connection, 0, numConns)
262+
263+
for i := 0; i < numConns; i++ {
264+
conn, err := p.get(context.Background())
265+
assert.Nil(t, err, "get error at index %d: %v", i, err)
266+
conns = append(conns, conn)
267+
}
268+
269+
return conns
270+
}

x/mongo/driver/topology/connection.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ type connection struct {
5151
connectContextMutex sync.Mutex
5252

5353
// pool related fields
54-
pool *pool
55-
poolID uint64
56-
generation uint64
54+
pool *pool
55+
poolID uint64
56+
generation uint64
57+
expireReason string
5758
}
5859

5960
// newConnection handles the creation of a connection. It does not connect the connection.
@@ -328,7 +329,11 @@ func (c *connection) close() error {
328329
return err
329330
}
330331

331-
func (c *connection) expired() bool {
332+
func (c *connection) closed() bool {
333+
return atomic.LoadInt32(&c.connected) == disconnected
334+
}
335+
336+
func (c *connection) idleTimeoutExpired() bool {
332337
now := time.Now()
333338
if c.idleTimeout > 0 {
334339
idleDeadline, ok := c.idleDeadline.Load().(time.Time)
@@ -340,8 +345,7 @@ func (c *connection) expired() bool {
340345
if !c.lifetimeDeadline.IsZero() && now.After(c.lifetimeDeadline) {
341346
return true
342347
}
343-
344-
return atomic.LoadInt32(&c.connected) == disconnected
348+
return false
345349
}
346350

347351
func (c *connection) bumpIdleDeadline() {

0 commit comments

Comments
 (0)