Skip to content

Commit 3759e9f

Browse files
author
Divjot Arora
authored
GODRIVER-1386 Use atomic.LoadInt32 in resource pool tests (#234)
1 parent d761f9f commit 3759e9f

File tree

1 file changed

+56
-54
lines changed

1 file changed

+56
-54
lines changed

x/mongo/driver/topology/resource_pool_test.go

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@
77
package topology
88

99
import (
10-
"github.com/stretchr/testify/require"
10+
"reflect"
1111
"sync/atomic"
1212
"testing"
1313
"time"
14+
15+
"github.com/google/go-cmp/cmp"
16+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1417
)
1518

19+
// rsrc is a mock resource used in resource pool tests.
20+
// This type should not be used other test files.
1621
type rsrc struct {
1722
closed bool
1823
}
@@ -35,8 +40,9 @@ func neverExpired(_ interface{}) bool {
3540

3641
// expiredCounter is used to implement an expiredFunc that will return true a fixed number of times.
3742
type expiredCounter struct {
38-
total, expiredCalled, closeCalled int32
39-
closeChan chan struct{}
43+
expiredCalled, closeCalled int32 // must be loaded/stored using atomic.*Int32 functions
44+
total int32
45+
closeChan chan struct{}
4046
}
4147

4248
func newExpiredCounter(total int32) expiredCounter {
@@ -48,17 +54,27 @@ func newExpiredCounter(total int32) expiredCounter {
4854

4955
func (ec *expiredCounter) expired(_ interface{}) bool {
5056
atomic.AddInt32(&ec.expiredCalled, 1)
51-
return ec.expiredCalled <= ec.total
57+
return ec.getExpiredCalled() <= ec.total
5258
}
5359

5460
func (ec *expiredCounter) close(_ interface{}) {
5561
atomic.AddInt32(&ec.closeCalled, 1)
56-
if ec.closeCalled == ec.total {
62+
if ec.getCloseCalled() == ec.total {
5763
ec.closeChan <- struct{}{}
5864
}
5965
}
6066

67+
func (ec *expiredCounter) getExpiredCalled() int32 {
68+
return atomic.LoadInt32(&ec.expiredCalled)
69+
}
70+
71+
func (ec *expiredCounter) getCloseCalled() int32 {
72+
return atomic.LoadInt32(&ec.closeCalled)
73+
}
74+
6175
func initPool(t *testing.T, minSize uint64, expFn expiredFunc, closeFn closeFunc, initFn initFunc, pruneInterval time.Duration) *resourcePool {
76+
t.Helper()
77+
6278
rpc := resourcePoolConfig{
6379
MinSize: minSize,
6480
MaintainInterval: pruneInterval,
@@ -67,70 +83,60 @@ func initPool(t *testing.T, minSize uint64, expFn expiredFunc, closeFn closeFunc
6783
InitFn: initFn,
6884
}
6985
rp, err := newResourcePool(rpc)
70-
require.NoError(t, err, "error creating new resource pool")
86+
assert.Nil(t, err, "error creating new resource pool: %v", err)
7187
rp.initialize()
7288
rp.maintainTimer.Reset(rp.maintainInterval)
7389
return rp
7490
}
7591

7692
func TestResourcePool(t *testing.T) {
93+
// register a cmp equality function for the rsrc type that will do a pointer comparison
94+
assert.RegisterOpts(reflect.TypeOf(&rsrc{}), cmp.Comparer(func(r1, r2 *rsrc) bool {
95+
return r1 == r2
96+
}))
97+
7798
t.Run("get", func(t *testing.T) {
7899
t.Run("remove stale resources", func(t *testing.T) {
79100
ec := newExpiredCounter(5)
80101
rp := initPool(t, 1, ec.expired, ec.close, initRsrc, time.Minute)
81102
rp.maintainTimer.Stop()
82103

83-
if got := rp.Get(); got != nil {
84-
t.Fatalf("resource mismatch; expected nil, got %v", got)
85-
}
86-
if rp.size != 0 {
87-
t.Fatalf("length mismatch; expected 0, got %d", rp.size)
88-
}
89-
if ec.expiredCalled != 1 {
90-
t.Fatalf("incorrect number of expire checks, expected 1, got %v", ec.expiredCalled)
91-
}
92-
if ec.closeCalled != 1 {
93-
t.Fatalf("incorrect number of closes called, expected 1, got %v", ec.closeCalled)
94-
}
104+
got := rp.Get()
105+
assert.Nil(t, got, "expected nil, got %v", got)
106+
assert.Equal(t, uint64(0), rp.size, "expected size 0, got %d", rp.size)
107+
108+
expiredCalled := ec.getExpiredCalled()
109+
assert.Equal(t, int32(1), expiredCalled, "expected expire to be called 1 time, got %v", expiredCalled)
110+
closeCalled := ec.getCloseCalled()
111+
assert.Equal(t, int32(1), closeCalled, "expected close to be called 1 time, got %v", closeCalled)
95112
})
96113
t.Run("recycle resources", func(t *testing.T) {
97114
rp := initPool(t, 1, neverExpired, closeRsrc, initRsrc, time.Minute)
98115
rp.maintainTimer.Stop()
99116
for i := 0; i < 5; i++ {
100117
got := rp.Get()
101-
if got == nil {
102-
t.Fatalf("resource mismatch; expected a resource but got nil")
103-
}
104-
if rp.size != 0 {
105-
t.Fatalf("length mismatch; expected 0, got %d", rp.size)
106-
}
118+
assert.NotNil(t, got, "expected resource, got nil")
119+
assert.Equal(t, uint64(0), rp.size, "expected size 0, got %v", rp.size)
120+
107121
rp.Put(got)
108-
if rp.size != 1 {
109-
t.Fatalf("length mismatch; expected 1, got %d", rp.size)
110-
}
122+
assert.Equal(t, uint64(1), rp.size, "expected size 1, got %v", rp.size)
111123
}
112124
})
113125
})
114126
t.Run("Put", func(t *testing.T) {
115127
t.Run("returned resources are returned to front of pool", func(t *testing.T) {
116128
rp := initPool(t, 0, neverExpired, closeRsrc, initRsrc, time.Minute)
117129
ret := &rsrc{}
118-
if !rp.Put(ret) {
119-
t.Fatal("return value mismatch; expected true, got false")
120-
}
121-
if rp.size != 1 {
122-
t.Fatalf("length mismatch; expected 1, got %d", rp.size)
123-
}
124-
if headVal := rp.Get(); headVal != ret {
125-
t.Fatalf("resource mismatch; expected %v at head of pool, got %v", ret, headVal)
126-
}
130+
assert.True(t, rp.Put(ret), "expected Put to return true, got false")
131+
assert.Equal(t, uint64(1), rp.size, "expected size 1, got %v", rp.size)
132+
133+
headVal := rp.Get()
134+
assert.Equal(t, ret, headVal, "expected resource %v at head of pool, got %v", ret, headVal)
127135
})
128136
t.Run("stale resource not returned", func(t *testing.T) {
129137
rp := initPool(t, 1, alwaysExpired, closeRsrc, initRsrc, time.Minute)
130138
ret := &rsrc{}
131-
if rp.Put(ret) {
132-
t.Fatal("return value mismatch; expected false, got true")
133-
}
139+
assert.False(t, rp.Put(ret), "expected Put to return false, got true")
134140
})
135141
})
136142
t.Run("Prune", func(t *testing.T) {
@@ -141,16 +147,14 @@ func TestResourcePool(t *testing.T) {
141147
ret := &rsrc{}
142148
_ = rp.Put(ret)
143149
}
150+
144151
rp.Maintain()
145-
if rp.size != 2 {
146-
t.Fatalf("length mismatch; expected 2, got %d", rp.size)
147-
}
148-
if ec.expiredCalled != 7 {
149-
t.Fatalf("count mismatch; expected ec.stale to be called 7 times, got %v", ec.expiredCalled)
150-
}
151-
if ec.closeCalled != 3 {
152-
t.Fatalf("count mismatch; expected ex.closeConnection to be called 3 times, got %v", ec.closeCalled)
153-
}
152+
assert.Equal(t, uint64(2), rp.size, "expected size 2, got %v", rp.size)
153+
154+
expiredCalled := ec.getExpiredCalled()
155+
assert.Equal(t, int32(7), expiredCalled, "expected expire to be called 7 times, got %v", expiredCalled)
156+
closeCalled := ec.getCloseCalled()
157+
assert.Equal(t, int32(3), closeCalled, "expected close to be called 3 times, got %v", closeCalled)
154158
})
155159
})
156160
t.Run("Background cleanup", func(t *testing.T) {
@@ -174,12 +178,10 @@ func TestResourcePool(t *testing.T) {
174178
t.Fatalf("value was not read on closeChan after 5 seconds")
175179
}
176180

177-
if atomic.LoadInt32(&ec.expiredCalled) != 5 {
178-
t.Fatalf("count mismatch; expected ec.stale to be called 5 times, got %v", ec.expiredCalled)
179-
}
180-
if atomic.LoadInt32(&ec.closeCalled) != 3 {
181-
t.Fatalf("count mismatch; expected ec.closeConnection to be called 5 times, got %v", ec.closeCalled)
182-
}
181+
expiredCalled := ec.getExpiredCalled()
182+
assert.Equal(t, int32(5), expiredCalled, "expected expire to be called 5 times, got %v", expiredCalled)
183+
closeCalled := ec.getCloseCalled()
184+
assert.Equal(t, int32(3), closeCalled, "expected close to be called 3 times, got %v", closeCalled)
183185
})
184186
})
185187
}

0 commit comments

Comments
 (0)