Skip to content

Commit bab1629

Browse files
author
Divjot Arora
committed
Add connected status to connection type.
GODRIVER-1173 Change-Id: I324fcba31c4065091a3449930fdbf88dc6b14afc
1 parent 5d7e5a2 commit bab1629

File tree

3 files changed

+20
-19
lines changed

3 files changed

+20
-19
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type connection struct {
4747
desc description.Server
4848
compressor wiremessage.CompressorID
4949
zliblevel int
50+
connected int32 // must be accessed using the sync/atomic package
5051

5152
// pool related fields
5253
pool *pool
@@ -91,6 +92,7 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
9192
readTimeout: cfg.readTimeout,
9293
writeTimeout: cfg.writeTimeout,
9394
}
95+
atomic.StoreInt32(&c.connected, connected)
9496

9597
c.bumpIdleDeadline()
9698

@@ -134,7 +136,7 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
134136

135137
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
136138
var err error
137-
if c.nc == nil {
139+
if atomic.LoadInt32(&c.connected) != connected {
138140
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
139141
}
140142
select {
@@ -168,7 +170,7 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
168170

169171
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
170172
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
171-
if c.nc == nil {
173+
if atomic.LoadInt32(&c.connected) != connected {
172174
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
173175
}
174176

@@ -231,12 +233,12 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e
231233
}
232234

233235
func (c *connection) close() error {
234-
if c.nc == nil {
236+
if atomic.LoadInt32(&c.connected) != connected {
235237
return nil
236238
}
237239
if c.pool == nil {
238240
err := c.nc.Close()
239-
c.nc = nil
241+
atomic.StoreInt32(&c.connected, disconnected)
240242
return err
241243
}
242244
return c.pool.close(c)
@@ -252,7 +254,7 @@ func (c *connection) expired() bool {
252254
return true
253255
}
254256

255-
return c.nc == nil
257+
return atomic.LoadInt32(&c.connected) == disconnected
256258
}
257259

258260
func (c *connection) bumpIdleDeadline() {

x/mongo/driver/topology/connection_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestConnection(t *testing.T) {
109109
t.Run("completed context", func(t *testing.T) {
110110
ctx, cancel := context.WithCancel(context.Background())
111111
cancel()
112-
conn := &connection{id: "foobar", nc: &net.TCPConn{}}
112+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
113113
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"}
114114
got := conn.writeWireMessage(ctx, []byte{})
115115
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
@@ -144,7 +144,7 @@ func TestConnection(t *testing.T) {
144144
message: "failed to set write deadline",
145145
}
146146
tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")}
147-
conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout}
147+
conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, connected: connected}
148148
got := conn.writeWireMessage(ctx, []byte{})
149149
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
150150
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -160,7 +160,7 @@ func TestConnection(t *testing.T) {
160160
err := errors.New("Write error")
161161
want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "unable to write wire message to network"}
162162
tnc := &testNetConn{writeerr: err}
163-
conn := &connection{id: "foobar", nc: tnc}
163+
conn := &connection{id: "foobar", nc: tnc, connected: connected}
164164
got := conn.writeWireMessage(context.Background(), []byte{})
165165
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
166166
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -170,7 +170,7 @@ func TestConnection(t *testing.T) {
170170
}
171171
})
172172
tnc := &testNetConn{}
173-
conn := &connection{id: "foobar", nc: tnc}
173+
conn := &connection{id: "foobar", nc: tnc, connected: connected}
174174
want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
175175
err := conn.writeWireMessage(context.Background(), want)
176176
noerr(t, err)
@@ -192,7 +192,7 @@ func TestConnection(t *testing.T) {
192192
t.Run("completed context", func(t *testing.T) {
193193
ctx, cancel := context.WithCancel(context.Background())
194194
cancel()
195-
conn := &connection{id: "foobar", nc: &net.TCPConn{}}
195+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
196196
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"}
197197
_, got := conn.readWireMessage(ctx, []byte{})
198198
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
@@ -227,7 +227,7 @@ func TestConnection(t *testing.T) {
227227
message: "failed to set read deadline",
228228
}
229229
tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")}
230-
conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout}
230+
conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, connected: connected}
231231
_, got := conn.readWireMessage(ctx, []byte{})
232232
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
233233
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -242,7 +242,7 @@ func TestConnection(t *testing.T) {
242242
err := errors.New("Read error")
243243
want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "unable to decode message length"}
244244
tnc := &testNetConn{readerr: err}
245-
conn := &connection{id: "foobar", nc: tnc}
245+
conn := &connection{id: "foobar", nc: tnc, connected: connected}
246246
_, got := conn.readWireMessage(context.Background(), []byte{})
247247
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
248248
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -255,7 +255,7 @@ func TestConnection(t *testing.T) {
255255
err := errors.New("Read error")
256256
want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "unable to read full message"}
257257
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
258-
conn := &connection{id: "foobar", nc: tnc}
258+
conn := &connection{id: "foobar", nc: tnc, connected: connected}
259259
_, got := conn.readWireMessage(context.Background(), []byte{})
260260
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
261261
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -268,7 +268,7 @@ func TestConnection(t *testing.T) {
268268
want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
269269
tnc := &testNetConn{buf: make([]byte, len(want))}
270270
copy(tnc.buf, want)
271-
conn := &connection{id: "foobar", nc: tnc}
271+
conn := &connection{id: "foobar", nc: tnc, connected: connected}
272272
got, err := conn.readWireMessage(context.Background(), nil)
273273
noerr(t, err)
274274
if !cmp.Equal(got, want) {

x/mongo/driver/topology/pool.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,12 @@ func (p *pool) close(c *connection) error {
182182
}
183183
p.Lock()
184184
delete(p.opened, c.poolID)
185-
nc := c.nc
186-
c.nc = nil
187185
p.Unlock()
188-
if nc == nil {
189-
return nil // We're closing an already closed connection.
186+
187+
if !atomic.CompareAndSwapInt32(&c.connected, connected, disconnected) {
188+
return nil // We're closing an already closed connection
190189
}
191-
err := nc.Close()
190+
err := c.nc.Close()
192191
if err != nil {
193192
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to close net.Conn"}
194193
}

0 commit comments

Comments
 (0)