Skip to content
This repository was archived by the owner on Mar 5, 2023. It is now read-only.

Commit 608dd63

Browse files
author
jonas747
committed
more rliable ratelimit bucket releases
1 parent 1f2f1d9 commit 608dd63

File tree

3 files changed

+101
-80
lines changed

3 files changed

+101
-80
lines changed

ratelimit.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ type RateLimiter struct {
2121
sync.Mutex
2222
global *int64
2323
buckets map[string]*Bucket
24-
globalRateLimit time.Duration
2524
customRateLimits []*customRateLimit
2625

2726
MaxConcurrentRequests int
@@ -61,9 +60,10 @@ func (r *RateLimiter) GetBucket(key string) *Bucket {
6160
}
6261

6362
b := &Bucket{
64-
Remaining: 1,
65-
Key: key,
66-
global: r.global,
63+
Remaining: 1,
64+
Key: key,
65+
global: r.global,
66+
lockCounter: new(int64),
6767
}
6868

6969
if r.MaxConcurrentRequests > 0 {
@@ -89,7 +89,7 @@ func (r *RateLimiter) GetWaitTime(b *Bucket, minRemaining int) time.Duration {
8989

9090
wait := time.Duration(0)
9191
if b.Remaining < minRemaining && b.reset.After(time.Now()) {
92-
wait = b.reset.Sub(time.Now())
92+
wait = time.Until(b.reset)
9393
}
9494

9595
// Check for global ratelimits
@@ -114,12 +114,14 @@ func (r *RateLimiter) GetWaitTime(b *Bucket, minRemaining int) time.Duration {
114114
}
115115

116116
// LockBucket Locks until a request can be made
117-
func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
118-
return r.LockBucketObject(r.GetBucket(bucketID))
117+
func (r *RateLimiter) LockBucket(bucketID string) (b *Bucket, lockID int64) {
118+
bucket := r.GetBucket(bucketID)
119+
id := r.LockBucketObject(bucket)
120+
return bucket, id
119121
}
120122

121123
// LockBucketObject Locks an already resolved bucket until a request can be made
122-
func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket {
124+
func (r *RateLimiter) LockBucketObject(b *Bucket) (lockID int64) {
123125
b.Lock()
124126

125127
if wait := r.GetWaitTime(b, 1); wait > 0 {
@@ -131,7 +133,7 @@ func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket {
131133
// sleep until were below the maximum
132134
for {
133135
numNow := atomic.AddInt32(r.numConcurrentLocks, 1)
134-
if int(numNow) >= r.MaxConcurrentRequests {
136+
if int(numNow) > r.MaxConcurrentRequests {
135137
atomic.AddInt32(r.numConcurrentLocks, -1)
136138
didWaitForMaxCCR = true
137139
time.Sleep(time.Millisecond * 25)
@@ -149,7 +151,9 @@ func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket {
149151
}
150152

151153
b.Remaining--
152-
return b
154+
155+
counter := atomic.AddInt64(b.lockCounter, 1)
156+
return counter
153157
}
154158

155159
func (r *RateLimiter) SetGlobalTriggered(to time.Time) {
@@ -161,28 +165,37 @@ type Bucket struct {
161165
sync.Mutex
162166
Key string
163167
Remaining int
164-
limit int
165168
reset time.Time
166169
global *int64
167170
numConcurrentLocks *int32
168171

169172
lastReset time.Time
170173
customRateLimit *customRateLimit
171174
Userdata interface{}
175+
176+
lockCounter *int64
172177
}
173178

174179
// Release unlocks the bucket and reads the headers to update the buckets ratelimit info
175180
// and locks up the whole thing in case if there's a global ratelimit.
176-
func (b *Bucket) Release(headers http.Header) error {
181+
func (b *Bucket) Release(headers http.Header, lockCounter int64) error {
182+
if atomic.LoadInt64(b.lockCounter) != lockCounter {
183+
// attempted double unlock
184+
return nil
185+
}
186+
177187
defer b.Unlock()
178188

189+
// make sure that we can no longer unlock with the same ID
190+
atomic.AddInt64(b.lockCounter, 1)
191+
179192
if b.numConcurrentLocks != nil {
180193
atomic.AddInt32(b.numConcurrentLocks, -1)
181194
}
182195

183196
// Check if the bucket uses a custom ratelimiter
184197
if rl := b.customRateLimit; rl != nil {
185-
if time.Now().Sub(b.lastReset) >= rl.reset {
198+
if time.Since(b.lastReset) >= rl.reset {
186199
b.Remaining = rl.requests - 1
187200
b.lastReset = time.Now()
188201
}

ratelimit_test.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func TestRatelimitReset(t *testing.T) {
1212
rl := NewRatelimiter()
1313

1414
sendReq := func(endpoint string) {
15-
bucket := rl.LockBucket(endpoint)
15+
bucket, id := rl.LockBucket(endpoint)
1616

1717
headers := http.Header(make(map[string][]string))
1818

@@ -21,7 +21,7 @@ func TestRatelimitReset(t *testing.T) {
2121
headers.Set("X-RateLimit-Reset-After", "2")
2222
headers.Set("Date", time.Now().Format(time.RFC850))
2323

24-
err := bucket.Release(headers)
24+
err := bucket.Release(headers, id)
2525
if err != nil {
2626
t.Errorf("Release returned error: %v", err)
2727
}
@@ -50,15 +50,15 @@ func TestRatelimitGlobal(t *testing.T) {
5050
rl := NewRatelimiter()
5151

5252
sendReq := func(endpoint string) {
53-
bucket := rl.LockBucket(endpoint)
53+
bucket, id := rl.LockBucket(endpoint)
5454

5555
headers := http.Header(make(map[string][]string))
5656

5757
headers.Set("X-RateLimit-Global", "1")
5858
// Reset for approx 1 seconds from now
5959
headers.Set("Retry-After", "1000")
6060

61-
err := bucket.Release(headers)
61+
err := bucket.Release(headers, id)
6262
if err != nil {
6363
t.Errorf("Release returned error: %v", err)
6464
}
@@ -82,31 +82,35 @@ func TestRatelimitGlobal(t *testing.T) {
8282

8383
func BenchmarkRatelimitSingleEndpoint(b *testing.B) {
8484
rl := NewRatelimiter()
85+
rl.MaxConcurrentRequests = 10
8586
for i := 0; i < b.N; i++ {
8687
sendBenchReq("/guilds/99/channels", rl)
8788
}
8889
}
8990

9091
func BenchmarkRatelimitParallelMultiEndpoints(b *testing.B) {
9192
rl := NewRatelimiter()
93+
rl.MaxConcurrentRequests = 10
9294
b.RunParallel(func(pb *testing.PB) {
9395
i := 0
9496
for pb.Next() {
9597
sendBenchReq("/guilds/"+strconv.Itoa(i)+"/channels", rl)
96-
i++
98+
// i++
9799
}
98100
})
99101
}
100102

101103
// Does not actually send requests, but locks the bucket and releases it with made-up headers
102104
func sendBenchReq(endpoint string, rl *RateLimiter) {
103-
bucket := rl.LockBucket(endpoint)
105+
bucket, id := rl.LockBucket(endpoint)
104106

105107
headers := http.Header(make(map[string][]string))
106108

107109
headers.Set("X-RateLimit-Remaining", "10")
108110
headers.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Unix(), 10))
109111
headers.Set("Date", time.Now().Format(time.RFC850))
110112

111-
bucket.Release(headers)
113+
time.Sleep(time.Millisecond * 100)
114+
115+
bucket.Release(headers, id)
112116
}

restapi.go

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID
7070
if bucketID == "" {
7171
bucketID = strings.SplitN(urlStr, "?", 2)[0]
7272
}
73-
return s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucket(bucketID))
73+
74+
return s.RequestWithBucket(method, urlStr, contentType, b, s.Ratelimiter.GetBucket(bucketID))
7475
}
7576

7677
type ReaderWithMockClose struct {
@@ -82,17 +83,12 @@ func (rwmc *ReaderWithMockClose) Close() error {
8283
}
8384

8485
// RequestWithLockedBucket makes a request using a bucket that's already been locked
85-
func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b []byte, bucket *Bucket) (response []byte, err error) {
86+
func (s *Session) RequestWithBucket(method, urlStr, contentType string, b []byte, bucket *Bucket) (response []byte, err error) {
8687

8788
for i := 0; i < s.MaxRestRetries; i++ {
88-
if i != 0 {
89-
// bucket is unlocked during retry downtimes, lock it here again
90-
s.Ratelimiter.LockBucketObject(bucket)
91-
}
92-
9389
var retry bool
9490
var ratelimited bool
95-
response, retry, ratelimited, err = s.doRequestLockedBucket(method, urlStr, contentType, b, bucket)
91+
response, retry, ratelimited, err = s.doRequest(method, urlStr, contentType, b, bucket)
9692
if !retry {
9793
break
9894
}
@@ -114,58 +110,14 @@ const (
114110
CtxKeyRatelimitBucket CtxKey = iota
115111
)
116112

117-
// RequestWithLockedBucket makes a request using a bucket that's already been locked
118-
func (s *Session) doRequestLockedBucket(method, urlStr, contentType string, b []byte, bucket *Bucket) (response []byte, retry bool, ratelimitRetry bool, err error) {
119-
if s.Debug {
120-
log.Printf("API REQUEST %8s :: %s\n", method, urlStr)
121-
log.Printf("API REQUEST PAYLOAD :: [%s]\n", string(b))
122-
}
113+
// doRequest makes a request using a bucket
114+
func (s *Session) doRequest(method, urlStr, contentType string, b []byte, bucket *Bucket) (response []byte, retry bool, ratelimitRetry bool, err error) {
123115

124116
if atomic.LoadInt32(s.tokenInvalid) != 0 {
125117
return nil, false, false, ErrTokenInvalid
126118
}
127119

128-
req, err := http.NewRequest(method, urlStr, bytes.NewReader(b))
129-
if err != nil {
130-
bucket.Release(nil)
131-
return
132-
}
133-
134-
req.GetBody = func() (io.ReadCloser, error) {
135-
return &ReaderWithMockClose{bytes.NewReader(b)}, nil
136-
}
137-
138-
// Not used on initial login..
139-
// TODO: Verify if a login, otherwise complain about no-token
140-
if s.Token != "" {
141-
req.Header.Set("authorization", s.Token)
142-
}
143-
144-
// Discord's API returns a 400 Bad Request is Content-Type is set, but the
145-
// request body is empty.
146-
if b != nil {
147-
req.Header.Set("Content-Type", contentType)
148-
}
149-
150-
// TODO: Make a configurable static variable.
151-
req.Header.Set("User-Agent", fmt.Sprintf("DiscordBot (https://github.com/jonas747/discordgo, v%s)", VERSION))
152-
req.Header.Set("X-RateLimit-Precision", "millisecond")
153-
154-
// for things such as stats collecting in the roundtripper for example
155-
ctx := context.WithValue(req.Context(), CtxKeyRatelimitBucket, bucket)
156-
req = req.WithContext(ctx)
157-
158-
if s.Debug {
159-
for k, v := range req.Header {
160-
log.Printf("API REQUEST HEADER :: [%s] = %+v\n", k, v)
161-
}
162-
}
163-
164-
resp, err := s.Client.Do(req)
165-
if err != nil {
166-
bucket.Release(nil)
167-
return nil, true, false, err
168-
}
120+
req, resp, err := s.innerDoRequest(method, urlStr, contentType, b, bucket)
169121

170122
defer func() {
171123
err2 := resp.Body.Close()
@@ -174,11 +126,6 @@ func (s *Session) doRequestLockedBucket(method, urlStr, contentType string, b []
174126
}
175127
}()
176128

177-
err = bucket.Release(resp.Header)
178-
if err != nil {
179-
return
180-
}
181-
182129
response, err = ioutil.ReadAll(resp.Body)
183130
if err != nil {
184131
return nil, true, false, err
@@ -242,6 +189,63 @@ func (s *Session) doRequestLockedBucket(method, urlStr, contentType string, b []
242189
return
243190
}
244191

192+
func (s *Session) innerDoRequest(method, urlStr, contentType string, b []byte, bucket *Bucket) (*http.Request, *http.Response, error) {
193+
bucketLockID := s.Ratelimiter.LockBucketObject(bucket)
194+
defer func() {
195+
err := bucket.Release(nil, bucketLockID)
196+
if err != nil {
197+
s.log(LogError, "failed unlocking ratelimit bucket: %v", err)
198+
}
199+
}()
200+
201+
if s.Debug {
202+
log.Printf("API REQUEST %8s :: %s\n", method, urlStr)
203+
log.Printf("API REQUEST PAYLOAD :: [%s]\n", string(b))
204+
}
205+
206+
req, err := http.NewRequest(method, urlStr, bytes.NewReader(b))
207+
if err != nil {
208+
return nil, nil, err
209+
}
210+
211+
req.GetBody = func() (io.ReadCloser, error) {
212+
return &ReaderWithMockClose{bytes.NewReader(b)}, nil
213+
}
214+
215+
// Not used on initial login..
216+
// TODO: Verify if a login, otherwise complain about no-token
217+
if s.Token != "" {
218+
req.Header.Set("authorization", s.Token)
219+
}
220+
221+
// Discord's API returns a 400 Bad Request is Content-Type is set, but the
222+
// request body is empty.
223+
if b != nil {
224+
req.Header.Set("Content-Type", contentType)
225+
}
226+
227+
// TODO: Make a configurable static variable.
228+
req.Header.Set("User-Agent", fmt.Sprintf("DiscordBot (https://github.com/jonas747/discordgo, v%s)", VERSION))
229+
req.Header.Set("X-RateLimit-Precision", "millisecond")
230+
231+
// for things such as stats collecting in the roundtripper for example
232+
ctx := context.WithValue(req.Context(), CtxKeyRatelimitBucket, bucket)
233+
req = req.WithContext(ctx)
234+
235+
if s.Debug {
236+
for k, v := range req.Header {
237+
log.Printf("API REQUEST HEADER :: [%s] = %+v\n", k, v)
238+
}
239+
}
240+
241+
resp, err := s.Client.Do(req)
242+
if err == nil {
243+
err = bucket.Release(resp.Header, bucketLockID)
244+
}
245+
246+
return req, resp, err
247+
}
248+
245249
func unmarshal(data []byte, v interface{}) error {
246250
err := json.Unmarshal(data, v)
247251
return err

0 commit comments

Comments
 (0)