Skip to content

Commit 9c88c8c

Browse files
KritiRavDivjot Arora
authored andcommitted
GODRIVER-1663 setting cancelConext to nil to avoid context being pinned (#436)
1 parent 497705c commit 9c88c8c

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ type connection struct {
4848
config *connectionConfig
4949
cancelConnectContext context.CancelFunc
5050
connectContextMade chan struct{}
51+
connectContextMutex sync.Mutex
5152

5253
// pool related fields
5354
pool *pool
@@ -105,7 +106,23 @@ func (c *connection) connect(ctx context.Context) {
105106
}
106107
defer close(c.connectDone)
107108

109+
c.connectContextMutex.Lock()
108110
ctx, c.cancelConnectContext = context.WithCancel(ctx)
111+
c.connectContextMutex.Unlock()
112+
113+
defer func() {
114+
var cancelFn context.CancelFunc
115+
116+
c.connectContextMutex.Lock()
117+
cancelFn = c.cancelConnectContext
118+
c.cancelConnectContext = nil
119+
c.connectContextMutex.Unlock()
120+
121+
if cancelFn != nil {
122+
cancelFn()
123+
}
124+
}()
125+
109126
close(c.connectContextMade)
110127

111128
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
@@ -188,7 +205,16 @@ func (c *connection) wait() error {
188205

189206
func (c *connection) closeConnectContext() {
190207
<-c.connectContextMade
191-
c.cancelConnectContext()
208+
var cancelFn context.CancelFunc
209+
210+
c.connectContextMutex.Lock()
211+
cancelFn = c.cancelConnectContext
212+
c.cancelConnectContext = nil
213+
c.connectContextMutex.Unlock()
214+
215+
if cancelFn != nil {
216+
cancelFn()
217+
}
192218
}
193219

194220
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,22 @@ func TestConnection(t *testing.T) {
125125
assert.NotNil(t, err, "expected connect error %v, got nil", want)
126126
assert.Equal(t, want, got, "expected error %v, got %v", want, got)
127127
})
128+
t.Run("cancelConnectContext is nil after connect", func(t *testing.T) {
129+
conn, err := newConnection(context.Background(), address.Address(""))
130+
assert.Nil(t, err, "newConnection shouldn't error. got %v; want nil", err)
131+
var wg sync.WaitGroup
132+
wg.Add(1)
133+
134+
go func() {
135+
defer wg.Done()
136+
conn.connect(context.Background())
137+
assert.Nil(t, conn.cancelConnectContext, "expected nil, got context.CancelFunc")
138+
}()
139+
140+
conn.closeConnectContext()
141+
assert.Nil(t, conn.cancelConnectContext, "expected nil, got context.CancelFunc")
142+
wg.Wait()
143+
})
128144
})
129145
t.Run("writeWireMessage", func(t *testing.T) {
130146
t.Run("closed connection", func(t *testing.T) {

0 commit comments

Comments
 (0)