@@ -3,7 +3,9 @@ package requests
33import (
44 "context"
55 "fmt"
6+ "math"
67 "sync"
8+ "sync/atomic"
79 "time"
810
911 "golang.org/x/time/rate"
@@ -19,22 +21,38 @@ import (
1921// - rampupPeriodDuration: The unit of each ramp-up period.
2022// - maxInFlight: The maximum number of concurrent requests, used to protect the client and the server.
2123// - requests: A slice of functions representing the API requests to be made.
22- func RampUpAPIRequests ( //nolint:gocognit, cyclop
24+ func RampUpAPIRequests ( //nolint:cyclop
2325 ctx context.Context ,
2426 minRPS , maxRPS , rampUpPeriod int , rampupPeriodDuration time.Duration , maxInFlight int ,
2527 requests []func () error ,
2628) error {
27- rpsIncrement := float64 (maxRPS - minRPS ) / float64 (rampUpPeriod )
28- limiter := rate .NewLimiter (rate .Limit (minRPS ), 1 )
29- semaphore := make (chan struct {}, maxInFlight )
29+ var (
30+ rpsIncrement = float64 (maxRPS - minRPS ) / float64 (rampUpPeriod )
31+ limiter = rate .NewLimiter (rate .Limit (minRPS ), 1 )
32+ semaphore = make (chan struct {}, maxInFlight )
33+ waitGroup sync.WaitGroup
34+ ticker = time .NewTicker (rampupPeriodDuration )
35+ requestIndex int32
36+ )
3037
31- var waitGroup sync.WaitGroup
32-
33- ticker := time .NewTicker (rampupPeriodDuration )
3438 defer ticker .Stop ()
3539
36- requestIndex := 0
37- requestsLen := len (requests )
40+ if len (requests ) > math .MaxInt32 {
41+ return fmt .Errorf ( //nolint:err113
42+ "too many requests in ramp up: %d. max supported is %d" , len (requests ), math .MaxInt32 ,
43+ )
44+ }
45+
46+ requestsLen := int32 (len (requests )) //nolint:gosec
47+
48+ worker := func (req func () error ) {
49+ defer waitGroup .Done ()
50+ defer func () { <- semaphore }()
51+
52+ if err := req (); err != nil {
53+ fmt .Printf ("Error: %v\n " , err )
54+ }
55+ }
3856
3957 for step := 0 ; step <= rampUpPeriod ; step ++ {
4058 select {
@@ -50,32 +68,19 @@ func RampUpAPIRequests( //nolint:gocognit,cyclop
5068 }
5169
5270 for i := 0 ; i < int (limiter .Limit ()); i ++ { //nolint:intrange
53- if requestIndex >= requestsLen {
71+ idx := atomic .AddInt32 (& requestIndex , 1 ) - 1
72+ if idx >= requestsLen {
5473 waitGroup .Wait ()
5574
5675 return nil
5776 }
5877
78+ req := requests [idx ]
5979 semaphore <- struct {}{}
6080
6181 waitGroup .Add (1 )
6282
63- go func (req func () error ) {
64- defer waitGroup .Done ()
65- defer func () { <- semaphore }()
66-
67- if req == nil {
68- fmt .Printf ("Error: request function is nil, request %d out of %d\n " , requestIndex , requestsLen )
69-
70- return
71- }
72-
73- if err := req (); err != nil {
74- fmt .Printf ("Error: %v\n " , err )
75- }
76- }(requests [requestIndex ])
77-
78- requestIndex ++
83+ go worker (req )
7984 }
8085
8186 newRPS := rate .Limit (float64 (minRPS ) + rpsIncrement * float64 (step ))
@@ -99,26 +104,19 @@ func RampUpAPIRequests( //nolint:gocognit,cyclop
99104 }
100105
101106 for i := 0 ; i < int (limiter .Limit ()); i ++ { //nolint:intrange
102- if requestIndex >= len (requests ) {
107+ idx := atomic .AddInt32 (& requestIndex , 1 ) - 1
108+ if idx >= requestsLen {
103109 waitGroup .Wait ()
104110
105111 return nil
106112 }
107113
114+ req := requests [idx ]
108115 semaphore <- struct {}{}
109116
110117 waitGroup .Add (1 )
111118
112- go func (req func () error ) {
113- defer waitGroup .Done ()
114- defer func () { <- semaphore }()
115-
116- if err := req (); err != nil {
117- fmt .Printf ("Error: %v\n " , err )
118- }
119- }(requests [requestIndex ])
120-
121- requestIndex ++
119+ go worker (req )
122120 }
123121 }
124122 }
0 commit comments