Skip to content

Commit 00c9d70

Browse files
committed
Add tests for the rate limit retry logic.
Signed-off-by: David Calavera <[email protected]>
1 parent 4225410 commit 00c9d70

File tree

4 files changed

+126
-15
lines changed

4 files changed

+126
-15
lines changed

go/Gopkg.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/Gopkg.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,7 @@
5252
[[constraint]]
5353
name = "github.com/Azure/go-autorest"
5454
version = "v9.6.0"
55+
56+
[[constraint]]
57+
name = "github.com/stretchr/testify"
58+
version = "v1.1.4"

go/porcelain/http/http.go

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,45 +26,68 @@ func NewRetryableTransport(tr runtime.ClientTransport, attempts int) *RetryableT
2626
}
2727
}
2828

29-
func (tr *RetryableTransport) Submit(op *runtime.ClientOperation) (interface{}, error) {
30-
op.Client.Transport = &retryableRoundTripper{
31-
tr: op.Client.Transport,
32-
attempts: tr.attempts,
29+
func (t *RetryableTransport) Submit(op *runtime.ClientOperation) (interface{}, error) {
30+
client := op.Client
31+
32+
if client == nil {
33+
client = http.DefaultClient
34+
}
35+
36+
transport := client.Transport
37+
if transport == nil {
38+
transport = http.DefaultTransport
3339
}
40+
client.Transport = &retryableRoundTripper{
41+
tr: transport,
42+
attempts: t.attempts,
43+
}
44+
45+
op.Client = client
3446

35-
return tr.Submit(op)
47+
return t.tr.Submit(op)
3648
}
3749

3850
func (tr *retryableRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
3951
rr := autorest.NewRetriableRequest(req)
4052

4153
// Increment to add the first call (attempts denotes number of retries)
42-
tr.attempts++
43-
for attempt := 0; attempt < tr.attempts; attempt++ {
54+
attempts := tr.attempts
55+
attempts++
56+
for attempt := 0; attempt < attempts; attempt++ {
4457
err = rr.Prepare()
4558
if err != nil {
4659
return resp, err
4760
}
48-
resp, err = tr.RoundTrip(rr.Request())
61+
62+
resp, err = tr.tr.RoundTrip(rr.Request())
63+
4964
if err != nil || resp.StatusCode != http.StatusTooManyRequests {
5065
return resp, err
5166
}
52-
delayWithRateLimit(resp, req.Cancel)
67+
68+
if !delayWithRateLimit(resp, req.Cancel) {
69+
return resp, err
70+
}
5371
}
72+
5473
return resp, err
5574
}
5675

57-
func delayWithRateLimit(resp *http.Response, cancel <-chan struct{}) {
58-
retryReset, err := strconv.ParseInt(resp.Header.Get("X-RateLimit-Reset"), 10, 0)
76+
func delayWithRateLimit(resp *http.Response, cancel <-chan struct{}) bool {
77+
r := resp.Header.Get("X-RateLimit-Reset")
78+
if r == "" {
79+
return false
80+
}
81+
retryReset, err := strconv.ParseInt(r, 10, 0)
5982
if err != nil {
60-
return
83+
return false
6184
}
6285

6386
t := time.Unix(retryReset, 0)
6487
select {
6588
case <-time.After(t.Sub(time.Now())):
66-
return
89+
return true
6790
case <-cancel:
68-
return
91+
return false
6992
}
7093
}

go/porcelain/http/http_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package http
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"net/url"
9+
"testing"
10+
"time"
11+
12+
"github.com/go-openapi/runtime"
13+
httptransport "github.com/go-openapi/runtime/client"
14+
"github.com/go-openapi/strfmt"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func TestRetryableTransport(t *testing.T) {
19+
attempts := 0
20+
21+
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
22+
attempts++
23+
24+
if attempts == 1 {
25+
reset := fmt.Sprintf("%d", time.Now().Add(1*time.Second).Unix())
26+
rw.Header().Set("X-RateLimit-Reset", reset)
27+
rw.WriteHeader(http.StatusTooManyRequests)
28+
_, _ = rw.Write([]byte("rate limited"))
29+
} else {
30+
rw.WriteHeader(http.StatusOK)
31+
_, _ = rw.Write([]byte("ok"))
32+
}
33+
}))
34+
defer server.Close()
35+
36+
rwrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, _ strfmt.Registry) error {
37+
return nil
38+
})
39+
40+
hu, _ := url.Parse(server.URL)
41+
rt := NewRetryableTransport(httptransport.New(hu.Host, "/", []string{"http"}), 2)
42+
43+
res, err := rt.Submit(&runtime.ClientOperation{
44+
ID: "getSites",
45+
Method: "GET",
46+
PathPattern: "/",
47+
Params: rwrtr,
48+
Reader: runtime.ClientResponseReaderFunc(func(response runtime.ClientResponse, consumer runtime.Consumer) (interface{}, error) {
49+
if response.Code() == 200 {
50+
var result string
51+
if err := consumer.Consume(response.Body(), &result); err != nil {
52+
return nil, err
53+
}
54+
return result, nil
55+
}
56+
return nil, errors.New("Generic error")
57+
}),
58+
})
59+
60+
require.NoError(t, err)
61+
require.Equal(t, 2, attempts)
62+
63+
require.IsType(t, "", res)
64+
actual := res.(string)
65+
require.EqualValues(t, "ok", actual)
66+
}

0 commit comments

Comments
 (0)