Skip to content

Commit 977c9ae

Browse files
xoac authored and Benjamin Rewis committed
Align atomically-accessed integer fields correctly (#723)
Uses int64 instead of int32 for atomically accessed integer fields and places those fields at the beginning of their enclosing structs. This avoids alignment-based panics on 32-bit architectures, where 64-bit atomic operations require 64-bit-aligned addresses.
1 parent b1393af commit 977c9ae

File tree

9 files changed

+73
-65
lines changed

9 files changed

+73
-65
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,11 @@ var (
3939
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
4040

4141
type connection struct {
42+
// connected must be accessed using the atomic package and should be at the beginning of the struct.
43+
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
44+
// - suggested layout: https://go101.org/article/memory-layout.html
45+
connected int64
46+
4247
id string
4348
nc net.Conn // When nil, the connection is closed.
4449
addr address.Address
@@ -52,7 +57,6 @@ type connection struct {
5257
compressor wiremessage.CompressorID
5358
zliblevel int
5459
zstdLevel int
55-
connected int32 // must be accessed using the sync/atomic package
5660
connectDone chan struct{}
5761
connectErr error
5862
config *connectionConfig
@@ -97,13 +101,13 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection,
97101
if !c.config.loadBalanced {
98102
c.setGenerationNumber()
99103
}
100-
atomic.StoreInt32(&c.connected, initialized)
104+
atomic.StoreInt64(&c.connected, initialized)
101105

102106
return c, nil
103107
}
104108

105109
func (c *connection) processInitializationError(opCtx context.Context, err error) {
106-
atomic.StoreInt32(&c.connected, disconnected)
110+
atomic.StoreInt64(&c.connected, disconnected)
107111
if c.nc != nil {
108112
_ = c.nc.Close()
109113
}
@@ -138,7 +142,7 @@ func (c *connection) hasGenerationNumber() bool {
138142
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
139143
// initialization handshakes.
140144
func (c *connection) connect(ctx context.Context) {
141-
if !atomic.CompareAndSwapInt32(&c.connected, initialized, connected) {
145+
if !atomic.CompareAndSwapInt64(&c.connected, initialized, connected) {
142146
return
143147
}
144148
defer close(c.connectDone)
@@ -345,7 +349,7 @@ func (c *connection) cancellationListenerCallback() {
345349

346350
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
347351
var err error
348-
if atomic.LoadInt32(&c.connected) != connected {
352+
if atomic.LoadInt64(&c.connected) != connected {
349353
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
350354
}
351355
select {
@@ -402,7 +406,7 @@ func (c *connection) write(ctx context.Context, wm []byte) (err error) {
402406

403407
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
404408
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
405-
if atomic.LoadInt32(&c.connected) != connected {
409+
if atomic.LoadInt64(&c.connected) != connected {
406410
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
407411
}
408412

@@ -505,7 +509,7 @@ func (c *connection) read(ctx context.Context, dst []byte) (bytesRead []byte, er
505509

506510
func (c *connection) close() error {
507511
// Overwrite the connection state as the first step so only the first close call will execute.
508-
if !atomic.CompareAndSwapInt32(&c.connected, connected, disconnected) {
512+
if !atomic.CompareAndSwapInt64(&c.connected, connected, disconnected) {
509513
return nil
510514
}
511515

@@ -518,7 +522,7 @@ func (c *connection) close() error {
518522
}
519523

520524
func (c *connection) closed() bool {
521-
return atomic.LoadInt32(&c.connected) == disconnected
525+
return atomic.LoadInt64(&c.connected) == disconnected
522526
}
523527

524528
func (c *connection) idleTimeoutExpired() bool {

x/mongo/driver/topology/connection_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func TestConnection(t *testing.T) {
8181
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
8282
t.Errorf("errors do not match. got %v; want %v", got, want)
8383
}
84-
connState := atomic.LoadInt32(&conn.connected)
84+
connState := atomic.LoadInt64(&conn.connected)
8585
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
8686
})
8787
t.Run("handshaker error", func(t *testing.T) {
@@ -109,7 +109,7 @@ func TestConnection(t *testing.T) {
109109
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
110110
t.Errorf("errors do not match. got %v; want %v", got, want)
111111
}
112-
connState := atomic.LoadInt32(&conn.connected)
112+
connState := atomic.LoadInt64(&conn.connected)
113113
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
114114
})
115115
t.Run("calls error callback", func(t *testing.T) {
@@ -753,7 +753,7 @@ func TestConnection(t *testing.T) {
753753
conn.connect(context.Background())
754754
err = conn.wait()
755755
assert.NotNil(t, err, "expected handshake error from wait, got nil")
756-
connState := atomic.LoadInt32(&conn.connected)
756+
connState := atomic.LoadInt64(&conn.connected)
757757
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
758758

759759
err = conn.close()

x/mongo/driver/topology/pool.go

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,23 @@ type checkOutResult struct {
5858

5959
// pool is a wrapper of resource pool that follows the CMAP spec for connection pools
6060
type pool struct {
61-
address address.Address
62-
opts []ConnectionOption
63-
conns *resourcePool // pool for non-checked out connections
64-
generation *poolGenerationMap
65-
monitor *event.PoolMonitor
66-
67-
// Must be accessed using the atomic package.
68-
connected int32
61+
// These fields must be accessed using the atomic package and should be at the beginning of the struct.
62+
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
63+
// - suggested layout: https://go101.org/article/memory-layout.html
64+
connected int64
6965
pinnedCursorConnections uint64
7066
pinnedTransactionConnections uint64
7167

7268
nextid uint64
7369
opened map[uint64]*connection // opened holds all of the currently open connections.
7470
sem *semaphore.Weighted
7571
sync.Mutex
72+
73+
address address.Address
74+
opts []ConnectionOption
75+
conns *resourcePool // pool for non-checked out connections
76+
generation *poolGenerationMap
77+
monitor *event.PoolMonitor
7678
}
7779

7880
// connectionExpiredFunc checks if a given connection is stale and should be removed from the resource pool
@@ -87,7 +89,7 @@ func connectionExpiredFunc(v interface{}) bool {
8789
}
8890

8991
switch {
90-
case atomic.LoadInt32(&c.pool.connected) != connected:
92+
case atomic.LoadInt64(&c.pool.connected) != connected:
9193
c.expireReason = event.ReasonPoolClosed
9294
case c.closed():
9395
// A connection would only be closed if it encountered a network error during an operation and closed itself.
@@ -208,7 +210,7 @@ func (p *pool) stale(c *connection) bool {
208210

209211
// connect puts the pool into the connected state, allowing it to be used and will allow items to begin being processed from the wait queue
210212
func (p *pool) connect() error {
211-
if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
213+
if !atomic.CompareAndSwapInt64(&p.connected, disconnected, connected) {
212214
return ErrPoolConnected
213215
}
214216
p.generation.connect()
@@ -218,7 +220,7 @@ func (p *pool) connect() error {
218220

219221
// disconnect disconnects the pool and closes all connections including those both in and out of the pool
220222
func (p *pool) disconnect(ctx context.Context) error {
221-
if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) {
223+
if !atomic.CompareAndSwapInt64(&p.connected, connected, disconnecting) {
222224
return ErrPoolDisconnected
223225
}
224226

@@ -267,7 +269,7 @@ func (p *pool) disconnect(ctx context.Context) error {
267269
_ = p.removeConnection(pc, event.ReasonPoolClosed)
268270
_ = p.closeConnection(pc) // We don't care about errors while closing the connection.
269271
}
270-
atomic.StoreInt32(&p.connected, disconnected)
272+
atomic.StoreInt64(&p.connected, disconnected)
271273
p.conns.clearTotal()
272274

273275
if p.monitor != nil {
@@ -301,7 +303,7 @@ func (p *pool) makeNewConnection() (*connection, string, error) {
301303
})
302304
}
303305

304-
if atomic.LoadInt32(&p.connected) != connected {
306+
if atomic.LoadInt64(&p.connected) != connected {
305307
// Manually publish a ConnectionClosed event here because the connection reference hasn't been stored and we
306308
// need to ensure each ConnectionCreated event has a corresponding ConnectionClosed event.
307309
if p.monitor != nil {
@@ -348,7 +350,7 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
348350
ctx = context.Background()
349351
}
350352

351-
if atomic.LoadInt32(&p.connected) != connected {
353+
if atomic.LoadInt64(&p.connected) != connected {
352354
if p.monitor != nil {
353355
p.monitor.Event(&event.PoolEvent{
354356
Type: event.GetFailed,
@@ -380,7 +382,7 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
380382
// This loop is so that we don't end up with more than maxPoolSize connections if p.conns.Maintain runs between
381383
// calling p.conns.Get() and making the new connection
382384
for {
383-
if atomic.LoadInt32(&p.connected) != connected {
385+
if atomic.LoadInt64(&p.connected) != connected {
384386
if p.monitor != nil {
385387
p.monitor.Event(&event.PoolEvent{
386388
Type: event.GetFailed,
@@ -395,7 +397,7 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
395397
connVal := p.conns.Get()
396398
if c, ok := connVal.(*connection); ok && connVal != nil {
397399
// call connect if not connected
398-
if atomic.LoadInt32(&c.connected) == initialized {
400+
if atomic.LoadInt64(&c.connected) == initialized {
399401
c.connect(ctx)
400402
}
401403

@@ -499,12 +501,12 @@ func (p *pool) closeConnection(c *connection) error {
499501
return ErrWrongPool
500502
}
501503

502-
if atomic.LoadInt32(&c.connected) == connected {
504+
if atomic.LoadInt64(&c.connected) == connected {
503505
c.closeConnectContext()
504506
_ = c.wait() // Make sure that the connection has finished connecting
505507
}
506508

507-
if !atomic.CompareAndSwapInt32(&c.connected, connected, disconnected) {
509+
if !atomic.CompareAndSwapInt64(&c.connected, connected, disconnected) {
508510
return nil // We're closing an already closed connection
509511
}
510512

x/mongo/driver/topology/pool_generation_counter.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ type generationStats struct {
2424
// load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
2525
// the load balancer will have a unique service ID.
2626
type poolGenerationMap struct {
27-
// state must be accessed using the atomic package.
28-
state int32
27+
// state must be accessed using the atomic package and should be at the beginning of the struct.
28+
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
29+
// - suggested layout: https://go101.org/article/memory-layout.html
30+
state int64
2931
generationMap map[primitive.ObjectID]*generationStats
3032

3133
sync.Mutex
@@ -40,11 +42,11 @@ func newPoolGenerationMap() *poolGenerationMap {
4042
}
4143

4244
func (p *poolGenerationMap) connect() {
43-
atomic.StoreInt32(&p.state, connected)
45+
atomic.StoreInt64(&p.state, connected)
4446
}
4547

4648
func (p *poolGenerationMap) disconnect() {
47-
atomic.StoreInt32(&p.state, disconnected)
49+
atomic.StoreInt64(&p.state, disconnected)
4850
}
4951

5052
// addConnection increments the connection count for the generation associated with the given service ID and returns the
@@ -100,7 +102,7 @@ func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
100102

101103
func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
102104
// If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
103-
if atomic.LoadInt32(&p.state) == disconnected {
105+
if atomic.LoadInt64(&p.state) == disconnected {
104106
return true
105107
}
106108

x/mongo/driver/topology/pool_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ func TestPool(t *testing.T) {
193193
t.Errorf("Should have closed 1 connections, but didn't. got %d; want %d", d.lenclosed(), 1)
194194
}
195195
close(cleanup)
196-
state := atomic.LoadInt32(&p.connected)
196+
state := atomic.LoadInt64(&p.connected)
197197
if state != disconnected {
198198
t.Errorf("Should have set the connection state on return. got %d; want %d", state, disconnected)
199199
}
@@ -250,10 +250,10 @@ func TestPool(t *testing.T) {
250250
})
251251
t.Run("connect", func(t *testing.T) {
252252
t.Run("can reconnect a disconnected pool", func(t *testing.T) {
253-
assertGenerationMapState := func(t *testing.T, p *pool, expectedState int32) {
253+
assertGenerationMapState := func(t *testing.T, p *pool, expectedState int64) {
254254
t.Helper()
255255

256-
actualState := atomic.LoadInt32(&p.generation.state)
256+
actualState := atomic.LoadInt64(&p.generation.state)
257257
assert.Equal(t, expectedState, actualState, "expected generation map state %d, got %d", expectedState, actualState)
258258
}
259259

@@ -296,7 +296,7 @@ func TestPool(t *testing.T) {
296296
t.Errorf("Pool should have 0 total connections. got %d; want %d", p.conns.totalSize, 0)
297297
}
298298
close(cleanup)
299-
state := atomic.LoadInt32(&p.connected)
299+
state := atomic.LoadInt64(&p.connected)
300300
if state != disconnected {
301301
t.Errorf("Should have set the connection state on return. got %d; want %d", state, disconnected)
302302
}

x/mongo/driver/topology/server.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ func (ss *SelectedServer) Description() description.SelectedServer {
5656

5757
// These constants represent the connection states of a server.
5858
const (
59-
disconnected int32 = iota
59+
disconnected int64 = iota
6060
disconnecting
6161
connected
6262
connecting
6363
initialized
6464
)
6565

66-
func connectionStateString(state int32) string {
66+
func connectionStateString(state int64) string {
6767
switch state {
6868
case 0:
6969
return "Disconnected"
@@ -84,7 +84,7 @@ func connectionStateString(state int32) string {
8484
type Server struct {
8585
cfg *serverConfig
8686
address address.Address
87-
connectionstate int32
87+
connectionstate int64
8888

8989
// connection related fields
9090
pool *pool
@@ -195,7 +195,7 @@ func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...Serv
195195
// Connect initializes the Server by starting background monitoring goroutines.
196196
// This method must be called before a Server can be used.
197197
func (s *Server) Connect(updateCallback updateTopologyCallback) error {
198-
if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) {
198+
if !atomic.CompareAndSwapInt64(&s.connectionstate, disconnected, connected) {
199199
return ErrServerConnected
200200
}
201201

@@ -225,7 +225,7 @@ func (s *Server) Connect(updateCallback updateTopologyCallback) error {
225225
// any in flight read or write operations. If this method returns with no
226226
// errors, all connections associated with this Server have been closed.
227227
func (s *Server) Disconnect(ctx context.Context) error {
228-
if !atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting) {
228+
if !atomic.CompareAndSwapInt64(&s.connectionstate, connected, disconnecting) {
229229
return ErrServerClosed
230230
}
231231

@@ -246,7 +246,7 @@ func (s *Server) Disconnect(ctx context.Context) error {
246246
}
247247

248248
s.closewg.Wait()
249-
atomic.StoreInt32(&s.connectionstate, disconnected)
249+
atomic.StoreInt64(&s.connectionstate, disconnected)
250250

251251
return nil
252252
}
@@ -261,7 +261,7 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
261261
})
262262
}
263263

264-
if atomic.LoadInt32(&s.connectionstate) != connected {
264+
if atomic.LoadInt64(&s.connectionstate) != connected {
265265
return nil, ErrServerClosed
266266
}
267267

@@ -379,7 +379,7 @@ func (s *Server) SelectedDescription() description.SelectedServer {
379379
// updated server descriptions will be sent. The channel will have a buffer
380380
// size of one, and will be pre-populated with the current description.
381381
func (s *Server) Subscribe() (*ServerSubscription, error) {
382-
if atomic.LoadInt32(&s.connectionstate) != connected {
382+
if atomic.LoadInt64(&s.connectionstate) != connected {
383383
return nil, ErrSubscribeAfterClosed
384384
}
385385
ch := make(chan description.Server, 1)
@@ -577,7 +577,7 @@ func (s *Server) update() {
577577
// Perform the next check.
578578
desc, err := s.check()
579579
if err == errCheckCancelled {
580-
if atomic.LoadInt32(&s.connectionstate) != connected {
580+
if atomic.LoadInt64(&s.connectionstate) != connected {
581581
continue
582582
}
583583

@@ -843,7 +843,7 @@ func extractTopologyVersion(err error) *description.TopologyVersion {
843843
// String implements the Stringer interface.
844844
func (s *Server) String() string {
845845
desc := s.Description()
846-
connState := atomic.LoadInt32(&s.connectionstate)
846+
connState := atomic.LoadInt64(&s.connectionstate)
847847
str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
848848
s.address, desc.Kind, connectionStateString(connState))
849849
if len(desc.Tags) != 0 {

0 commit comments

Comments
 (0)