Skip to content

Commit d4fe375

Browse files
authored
GODRIVER-1663 setting cancelConext to nil to avoid context being pinned (#436)
1 parent 3191cb1 commit d4fe375

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type connection struct {
5151
connectContextMade chan struct{}
5252
canStream bool
5353
currentlyStreaming bool
54+
connectContextMutex sync.Mutex
5455

5556
// pool related fields
5657
pool *pool
@@ -108,7 +109,23 @@ func (c *connection) connect(ctx context.Context) {
108109
}
109110
defer close(c.connectDone)
110111

112+
c.connectContextMutex.Lock()
111113
ctx, c.cancelConnectContext = context.WithCancel(ctx)
114+
c.connectContextMutex.Unlock()
115+
116+
defer func() {
117+
var cancelFn context.CancelFunc
118+
119+
c.connectContextMutex.Lock()
120+
cancelFn = c.cancelConnectContext
121+
c.cancelConnectContext = nil
122+
c.connectContextMutex.Unlock()
123+
124+
if cancelFn != nil {
125+
cancelFn()
126+
}
127+
}()
128+
112129
close(c.connectContextMade)
113130

114131
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
@@ -195,7 +212,16 @@ func (c *connection) wait() error {
195212

196213
func (c *connection) closeConnectContext() {
197214
<-c.connectContextMade
198-
c.cancelConnectContext()
215+
var cancelFn context.CancelFunc
216+
217+
c.connectContextMutex.Lock()
218+
cancelFn = c.cancelConnectContext
219+
c.cancelConnectContext = nil
220+
c.connectContextMutex.Unlock()
221+
222+
if cancelFn != nil {
223+
cancelFn()
224+
}
199225
}
200226

201227
func transformNetworkError(originalError error, contextDeadlineUsed bool) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,22 @@ func TestConnection(t *testing.T) {
110110
assert.NotNil(t, err, "expected connect error %v, got nil", want)
111111
assert.Equal(t, want, got, "expected error %v, got %v", want, got)
112112
})
113+
t.Run("cancelConnectContext is nil after connect", func(t *testing.T) {
114+
conn, err := newConnection(context.Background(), address.Address(""))
115+
assert.Nil(t, err, "newConnection shouldn't error. got %v; want nil", err)
116+
var wg sync.WaitGroup
117+
wg.Add(1)
118+
119+
go func() {
120+
defer wg.Done()
121+
conn.connect(context.Background())
122+
assert.Nil(t, conn.cancelConnectContext, "expected nil, got context.CancelFunc")
123+
}()
124+
125+
conn.closeConnectContext()
126+
assert.Nil(t, conn.cancelConnectContext, "expected nil, got context.CancelFunc")
127+
wg.Wait()
128+
})
113129
})
114130
t.Run("writeWireMessage", func(t *testing.T) {
115131
t.Run("closed connection", func(t *testing.T) {

0 commit comments

Comments
 (0)