Skip to content

Commit dde85c6

Browse files
committed
fix: race condition in ramp up
1 parent 0a50bce commit dde85c6

File tree

2 files changed

+65
-44
lines changed

2 files changed

+65
-44
lines changed

internal/requests/rampup.go

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package requests
33
import (
44
"context"
55
"fmt"
6+
"math"
67
"sync"
8+
"sync/atomic"
79
"time"
810

911
"golang.org/x/time/rate"
@@ -19,22 +21,38 @@ import (
1921
// - rampupPeriodDuration: The unit of each ramp-up period.
2022
// - maxInFlight: The maximum number of concurrent requests, used to protect the client and the server.
2123
// - requests: A slice of functions representing the API requests to be made.
22-
func RampUpAPIRequests( //nolint:gocognit,cyclop
24+
func RampUpAPIRequests( //nolint:cyclop
2325
ctx context.Context,
2426
minRPS, maxRPS, rampUpPeriod int, rampupPeriodDuration time.Duration, maxInFlight int,
2527
requests []func() error,
2628
) error {
27-
rpsIncrement := float64(maxRPS-minRPS) / float64(rampUpPeriod)
28-
limiter := rate.NewLimiter(rate.Limit(minRPS), 1)
29-
semaphore := make(chan struct{}, maxInFlight)
29+
var (
30+
rpsIncrement = float64(maxRPS-minRPS) / float64(rampUpPeriod)
31+
limiter = rate.NewLimiter(rate.Limit(minRPS), 1)
32+
semaphore = make(chan struct{}, maxInFlight)
33+
waitGroup sync.WaitGroup
34+
ticker = time.NewTicker(rampupPeriodDuration)
35+
requestIndex int32
36+
)
3037

31-
var waitGroup sync.WaitGroup
32-
33-
ticker := time.NewTicker(rampupPeriodDuration)
3438
defer ticker.Stop()
3539

36-
requestIndex := 0
37-
requestsLen := len(requests)
40+
if len(requests) > math.MaxInt32 {
41+
return fmt.Errorf( //nolint:err113
42+
"too many requests in ramp up: %d. max supported is %d", len(requests), math.MaxInt32,
43+
)
44+
}
45+
46+
requestsLen := int32(len(requests)) //nolint:gosec
47+
48+
worker := func(req func() error) {
49+
defer waitGroup.Done()
50+
defer func() { <-semaphore }()
51+
52+
if err := req(); err != nil {
53+
fmt.Printf("Error: %v\n", err)
54+
}
55+
}
3856

3957
for step := 0; step <= rampUpPeriod; step++ {
4058
select {
@@ -50,32 +68,19 @@ func RampUpAPIRequests( //nolint:gocognit,cyclop
5068
}
5169

5270
for i := 0; i < int(limiter.Limit()); i++ { //nolint:intrange
53-
if requestIndex >= requestsLen {
71+
idx := atomic.AddInt32(&requestIndex, 1) - 1
72+
if idx >= requestsLen {
5473
waitGroup.Wait()
5574

5675
return nil
5776
}
5877

78+
req := requests[idx]
5979
semaphore <- struct{}{}
6080

6181
waitGroup.Add(1)
6282

63-
go func(req func() error) {
64-
defer waitGroup.Done()
65-
defer func() { <-semaphore }()
66-
67-
if req == nil {
68-
fmt.Printf("Error: request function is nil, request %d out of %d\n", requestIndex, requestsLen)
69-
70-
return
71-
}
72-
73-
if err := req(); err != nil {
74-
fmt.Printf("Error: %v\n", err)
75-
}
76-
}(requests[requestIndex])
77-
78-
requestIndex++
83+
go worker(req)
7984
}
8085

8186
newRPS := rate.Limit(float64(minRPS) + rpsIncrement*float64(step))
@@ -99,26 +104,19 @@ func RampUpAPIRequests( //nolint:gocognit,cyclop
99104
}
100105

101106
for i := 0; i < int(limiter.Limit()); i++ { //nolint:intrange
102-
if requestIndex >= len(requests) {
107+
idx := atomic.AddInt32(&requestIndex, 1) - 1
108+
if idx >= requestsLen {
103109
waitGroup.Wait()
104110

105111
return nil
106112
}
107113

114+
req := requests[idx]
108115
semaphore <- struct{}{}
109116

110117
waitGroup.Add(1)
111118

112-
go func(req func() error) {
113-
defer waitGroup.Done()
114-
defer func() { <-semaphore }()
115-
116-
if err := req(); err != nil {
117-
fmt.Printf("Error: %v\n", err)
118-
}
119-
}(requests[requestIndex])
120-
121-
requestIndex++
119+
go worker(req)
122120
}
123121
}
124122
}

internal/requests/rampup_test.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package requests_test
33
import (
44
"context"
55
"errors"
6+
"sync"
67
"sync/atomic"
78
"testing"
89
"time"
@@ -13,12 +14,17 @@ import (
1314
func TestRampUpAPIRequests_Success(t *testing.T) {
1415
t.Parallel()
1516

16-
var callCount int32
17+
var (
18+
callCount int32
19+
mutex sync.Mutex
20+
)
1721

1822
requestsList := make([]func() error, 5)
1923
for i := range requestsList {
2024
requestsList[i] = func() error {
25+
mutex.Lock()
2126
atomic.AddInt32(&callCount, 1)
27+
mutex.Unlock()
2228

2329
return nil
2430
}
@@ -29,7 +35,7 @@ func TestRampUpAPIRequests_Success(t *testing.T) {
2935
t.Fatalf("expected no error, got %v, %v", err, callCount)
3036
}
3137

32-
if atomic.LoadInt32(&callCount) != int32(len(requestsList)) { //nolint:gosec
38+
if callCount != int32(len(requestsList)) { //nolint:gosec
3339
t.Fatalf("expected %d calls, got %d", len(requestsList), callCount)
3440
}
3541
}
@@ -40,12 +46,17 @@ func TestRampUpAPIRequests_RampUpRate(t *testing.T) {
4046
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
4147
defer cancel()
4248

43-
var callCount int32
49+
var (
50+
callCount int32
51+
mutex sync.Mutex
52+
)
4453

4554
requestsList := make([]func() error, 10)
4655
for i := range requestsList {
4756
requestsList[i] = func() error {
57+
mutex.Lock()
4858
atomic.AddInt32(&callCount, 1)
59+
mutex.Unlock()
4960

5061
return nil
5162
}
@@ -72,12 +83,17 @@ func TestRampUpAPIRequests_ContextCancelled(t *testing.T) {
7283
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
7384
defer cancel()
7485

75-
var callCount int32
86+
var (
87+
callCount int32
88+
mutex sync.Mutex
89+
)
7690

7791
requestsList := make([]func() error, 100)
7892
for i := range requestsList {
7993
requestsList[i] = func() error {
94+
mutex.Lock()
8095
atomic.AddInt32(&callCount, 1)
96+
mutex.Unlock()
8197
time.Sleep(100 * time.Millisecond)
8298

8399
return nil
@@ -89,7 +105,7 @@ func TestRampUpAPIRequests_ContextCancelled(t *testing.T) {
89105
t.Fatalf("expected error, got nil")
90106
}
91107

92-
if atomic.LoadInt32(&callCount) == int32(len(requestsList)) { //nolint:gosec
108+
if callCount == int32(len(requestsList)) { //nolint:gosec
93109
t.Fatalf("expected fewer than %d calls, got %d", len(requestsList), callCount)
94110
}
95111
}
@@ -100,16 +116,23 @@ func TestRampUpAPIRequests_RequestError(t *testing.T) {
100116
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
101117
defer cancel()
102118

103-
var callCount int32
119+
var (
120+
callCount int32
121+
mutex sync.Mutex
122+
)
104123

105124
requestsList := make([]func() error, 5)
106125
for i := range requestsList {
107126
requestsList[i] = func() error {
127+
mutex.Lock()
108128
atomic.AddInt32(&callCount, 1)
109129

110130
if callCount == 2 {
131+
mutex.Unlock()
132+
111133
return errors.New("request error")
112134
}
135+
mutex.Unlock()
113136

114137
return nil
115138
}
@@ -120,7 +143,7 @@ func TestRampUpAPIRequests_RequestError(t *testing.T) {
120143
t.Fatalf("expected no error, got %v", err)
121144
}
122145

123-
if atomic.LoadInt32(&callCount) != int32(len(requestsList)) { //nolint:gosec
146+
if callCount != int32(len(requestsList)) { //nolint:gosec
124147
t.Fatalf("expected %d calls, got %d", len(requestsList), callCount)
125148
}
126149
}

0 commit comments

Comments
 (0)