Skip to content

Commit 10abf0f

Browse files
authored
GODRIVER-2223 Fix data races caused by unsynchronized access to rand.Rand instances. (#808)
1 parent a21c36b commit 10abf0f

File tree

4 files changed

+61
-6
lines changed

4 files changed

+61
-6
lines changed

internal/randutil/randutil.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Package randutil provides common random number utilities.
2+
package randutil
3+
4+
import (
5+
"math/rand"
6+
"sync"
7+
)
8+
9+
// A LockedRand wraps a "math/rand".Rand and is safe to use from multiple goroutines.
10+
type LockedRand struct {
11+
mu sync.Mutex
12+
r *rand.Rand
13+
}
14+
15+
// NewLockedRand returns a new LockedRand that uses random values from src to generate other random
16+
// values. It is safe to use from multiple goroutines.
17+
func NewLockedRand(src rand.Source) *LockedRand {
18+
return &LockedRand{
19+
r: rand.New(src),
20+
}
21+
}
22+
23+
// Read generates len(p) random bytes and writes them into p. It always returns len(p) and a nil
24+
// error.
25+
func (lr *LockedRand) Read(p []byte) (int, error) {
26+
lr.mu.Lock()
27+
n, err := lr.r.Read(p)
28+
lr.mu.Unlock()
29+
return n, err
30+
}
31+
32+
// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n). It
33+
// panics if n <= 0.
34+
func (lr *LockedRand) Intn(n int) int {
35+
lr.mu.Lock()
36+
x := lr.r.Intn(n)
37+
lr.mu.Unlock()
38+
return x
39+
}
40+
41+
// Shuffle pseudo-randomizes the order of elements. n is the number of elements. Shuffle panics if
42+
// n < 0. swap swaps the elements with indexes i and j.
43+
//
44+
// Note that Shuffle locks the LockedRand, so shuffling large collections may adversely affect other
45+
// concurrent calls. If many concurrent Shuffle and random value calls are required, consider using
46+
// the global "math/rand".Shuffle instead because it uses much more granular locking.
47+
func (lr *LockedRand) Shuffle(n int, swap func(i, j int)) {
48+
lr.mu.Lock()
49+
lr.r.Shuffle(n, swap)
50+
lr.mu.Unlock()
51+
}

x/mongo/driver/connstring/connstring.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ import (
1717
"time"
1818

1919
"go.mongodb.org/mongo-driver/internal"
20+
"go.mongodb.org/mongo-driver/internal/randutil"
2021
"go.mongodb.org/mongo-driver/mongo/writeconcern"
2122
"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
2223
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
2324
)
2425

25-
// random is a package-global pseudo-random number source.
26-
var random = rand.New(rand.NewSource(time.Now().UnixNano()))
26+
// random is a package-global pseudo-random number generator.
27+
var random = randutil.NewLockedRand(rand.NewSource(time.Now().UnixNano()))
2728

2829
// ParseAndValidate parses the provided URI into a ConnString object.
2930
// It check that all values are valid.

x/mongo/driver/topology/topology.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"go.mongodb.org/mongo-driver/bson/primitive"
2525
"go.mongodb.org/mongo-driver/event"
26+
"go.mongodb.org/mongo-driver/internal/randutil"
2627
"go.mongodb.org/mongo-driver/mongo/address"
2728
"go.mongodb.org/mongo-driver/mongo/description"
2829
"go.mongodb.org/mongo-driver/x/mongo/driver"
@@ -48,8 +49,8 @@ var ErrServerSelectionTimeout = errors.New("server selection timeout")
4849
// MonitorMode represents the way in which a server is monitored.
4950
type MonitorMode uint8
5051

51-
// random is a package-global pseudo-random number source.
52-
var random = rand.New(rand.NewSource(time.Now().UnixNano()))
52+
// random is a package-global pseudo-random number generator.
53+
var random = randutil.NewLockedRand(rand.NewSource(time.Now().UnixNano()))
5354

5455
// These constants are the available monitoring modes.
5556
const (

x/mongo/driver/uuid/uuid.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ import (
1010
"io"
1111
"math/rand"
1212
"time"
13+
14+
"go.mongodb.org/mongo-driver/internal/randutil"
1315
)
1416

1517
// UUID represents a UUID.
1618
type UUID [16]byte
1719

18-
// random is a package-global pseudo-random number source.
19-
var random = rand.New(rand.NewSource(time.Now().UnixNano()))
20+
// random is a package-global pseudo-random number generator.
21+
var random = randutil.NewLockedRand(rand.NewSource(time.Now().UnixNano()))
2022

2123
// New returns a random UUIDv4. It uses a "math/rand" pseudo-random number generator seeded with the
2224
// package initialization time.

0 commit comments

Comments
 (0)