Skip to content

Commit 22836c7

Browse files
committed
fix: handle slow_down error
Signed-off-by: Babak K. Shandiz <babakks@github.com>
1 parent 6c44f68 commit 22836c7

File tree

3 files changed

+412
-21
lines changed

3 files changed

+412
-21
lines changed

device/device_flow.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ type WaitOptions struct {
148148

149149
// Wait polls the server at uri until authorization completes.
150150
func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api.AccessToken, error) {
151-
checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second
151+
baseCheckInterval := time.Duration(opts.DeviceCode.Interval) * time.Second
152152
expiresIn := time.Duration(opts.DeviceCode.ExpiresIn) * time.Second
153153
grantType := opts.GrantType
154154
if opts.GrantType == "" {
@@ -159,7 +159,7 @@ func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api
159159
if makePoller == nil {
160160
makePoller = newPoller
161161
}
162-
_, poll := makePoller(ctx, checkInterval, expiresIn)
162+
_, poll := makePoller(ctx, baseCheckInterval, expiresIn)
163163

164164
for {
165165
if err := poll.Wait(); err != nil {
@@ -187,8 +187,35 @@ func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api
187187
token, err := resp.AccessToken()
188188
if err == nil {
189189
return token, nil
190-
} else if !(errors.As(err, &apiError) && apiError.Code == "authorization_pending") {
190+
}
191+
192+
if !errors.As(err, &apiError) {
191193
return nil, err
192194
}
195+
196+
if apiError.Code == "authorization_pending" {
197+
// Keep polling
198+
continue
199+
}
200+
201+
if apiError.Code == "slow_down" {
202+
// Based on the RFC spec, we must add 5 seconds to our current polling interval.
203+
// (See https://www.rfc-editor.org/rfc/rfc8628#section-3.5)
204+
newInterval := poll.GetInterval() + 5*time.Second
205+
206+
// GitHub OAuth API returns the new interval in seconds in the response.
207+
// We should try to use that if provided. It's okay if we couldn't find
208+
// it as we have already increased our interval as of the RFC spec.
209+
if s := resp.Get("interval"); s != "" {
210+
if v, err := strconv.ParseUint(s, 10, 64); err == nil && v > 0 {
211+
newInterval = time.Duration(v) * time.Second
212+
}
213+
}
214+
215+
poll.SetInterval(newInterval)
216+
continue
217+
}
218+
219+
return nil, err
193220
}
194221
}

0 commit comments

Comments
 (0)