Skip to content

Commit b504c38

Browse files
benjirewisbenjirewis
authored andcommitted
GODRIVER-2533 Fix data race from NumberSessionsInProgress. (#1085)
1 parent 404feab commit b504c38

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

mongo/client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,10 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{},
10991099
// NumberSessionsInProgress returns the number of sessions that have been started for this client but have not been
11001100
// closed (i.e. EndSession has not been called).
11011101
func (c *Client) NumberSessionsInProgress() int {
1102-
return c.sessionPool.CheckedOut()
1102+
// The underlying session pool uses an int64 for checkedOut to allow atomic
1103+
// access. We convert to an int here to maintain backward compatibility with
1104+
// older versions of the driver that did not atomically access checkedOut.
1105+
return int(c.sessionPool.CheckedOut())
11031106
}
11041107

11051108
func (c *Client) createBaseCursorOptions() driver.CursorOptions {

mongo/integration/sessions_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"bytes"
1111
"context"
1212
"reflect"
13+
"sync"
1314
"testing"
1415
"time"
1516

@@ -358,6 +359,35 @@ func TestSessions(t *testing.T) {
358359
assert.Equal(mt, err, mongo.ErrUnacknowledgedWrite,
359360
"expected ErrUnacknowledgedWrite on unacknowledged write in session, got %v", err)
360361
})
362+
363+
// Regression test for GODRIVER-2533. Note that this test assumes the race
364+
// detector is enabled (GODRIVER-2072).
365+
mt.Run("NumberSessionsInProgress data race", func(mt *mtest.T) {
366+
// Use two goroutines to execute a few simultaneous runs of NumberSessionsInProgress
367+
// and a basic collection operation (CountDocuments).
368+
var wg sync.WaitGroup
369+
wg.Add(2)
370+
371+
go func() {
372+
defer wg.Done()
373+
374+
for i := 0; i < 100; i++ {
375+
time.Sleep(100 * time.Microsecond)
376+
_ = mt.Client.NumberSessionsInProgress()
377+
}
378+
}()
379+
go func() {
380+
defer wg.Done()
381+
382+
for i := 0; i < 100; i++ {
383+
time.Sleep(100 * time.Microsecond)
384+
_, err := mt.Coll.CountDocuments(context.Background(), bson.D{})
385+
assert.Nil(mt, err, "CountDocument error: %v", err)
386+
}
387+
}()
388+
389+
wg.Wait()
390+
})
361391
}
362392

363393
func assertCollectionCount(mt *mtest.T, expectedCount int64) {

x/mongo/driver/session/session_pool.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package session
88

99
import (
1010
"sync"
11+
"sync/atomic"
1112

1213
"go.mongodb.org/mongo-driver/mongo/description"
1314
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
@@ -29,13 +30,14 @@ type topologyDescription struct {
2930

3031
// Pool is a pool of server sessions that can be reused.
3132
type Pool struct {
33+
// number of sessions checked out of pool (accessed atomically)
34+
checkedOut int64
35+
3236
descChan <-chan description.Topology
3337
head *Node
3438
tail *Node
3539
latestTopology topologyDescription
3640
mutex sync.Mutex // mutex to protect list and sessionTimeout
37-
38-
checkedOut int // number of sessions checked out of pool
3941
}
4042

4143
func (p *Pool) createServerSession() (*Server, error) {
@@ -44,7 +46,7 @@ func (p *Pool) createServerSession() (*Server, error) {
4446
return nil, err
4547
}
4648

47-
p.checkedOut++
49+
atomic.AddInt64(&p.checkedOut, 1)
4850
return s, nil
4951
}
5052

@@ -100,7 +102,7 @@ func (p *Pool) GetSession() (*Server, error) {
100102
p.head = p.head.next
101103
}
102104

103-
p.checkedOut++
105+
atomic.AddInt64(&p.checkedOut, 1)
104106
return session, nil
105107
}
106108

@@ -118,7 +120,7 @@ func (p *Pool) ReturnSession(ss *Server) {
118120
p.mutex.Lock()
119121
defer p.mutex.Unlock()
120122

121-
p.checkedOut--
123+
atomic.AddInt64(&p.checkedOut, -1)
122124
p.updateTimeout()
123125
// check sessions at end of queue for expired
124126
// stop checking after hitting the first valid session
@@ -185,6 +187,6 @@ func (p *Pool) String() string {
185187
}
186188

187189
// CheckedOut returns number of sessions checked out from pool.
188-
func (p *Pool) CheckedOut() int {
189-
return p.checkedOut
190+
func (p *Pool) CheckedOut() int64 {
191+
return atomic.LoadInt64(&p.checkedOut)
190192
}

0 commit comments

Comments
 (0)