Skip to content

Commit a1bd7b2

Browse files
committed
Restore original transport for each operation.
Signed-off-by: David Calavera <[email protected]>
1 parent 00c9d70 commit a1bd7b2

File tree

2 files changed

+91
-9
lines changed

2 files changed

+91
-9
lines changed

go/porcelain/http/http.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,33 @@ func (t *RetryableTransport) Submit(op *runtime.ClientOperation) (interface{}, e
4444

4545
op.Client = client
4646

47-
return t.tr.Submit(op)
47+
res, err := t.tr.Submit(op)
48+
49+
// restore original transport
50+
op.Client.Transport = transport
51+
52+
return res, err
4853
}
4954

50-
func (tr *retryableRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
55+
func (t *retryableRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
5156
rr := autorest.NewRetriableRequest(req)
5257

53-
// Increment to add the first call (attempts denotes number of retries)
54-
attempts := tr.attempts
55-
attempts++
56-
for attempt := 0; attempt < attempts; attempt++ {
58+
for attempt := 0; attempt < t.attempts; attempt++ {
5759
err = rr.Prepare()
5860
if err != nil {
5961
return resp, err
6062
}
6163

62-
resp, err = tr.tr.RoundTrip(rr.Request())
64+
resp, err = t.tr.RoundTrip(rr.Request())
6365

6466
if err != nil || resp.StatusCode != http.StatusTooManyRequests {
6567
return resp, err
6668
}
6769

68-
if !delayWithRateLimit(resp, req.Cancel) {
69-
return resp, err
70+
if attempt+1 < t.attempts { // ignore delay check in the last request attempt
71+
if !delayWithRateLimit(resp, req.Cancel) {
72+
return resp, err
73+
}
7074
}
7175
}
7276

go/porcelain/http/http_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,81 @@ func TestRetryableTransport(t *testing.T) {
6464
actual := res.(string)
6565
require.EqualValues(t, "ok", actual)
6666
}
67+
68+
func TestRetryableTransportExceedsMaxAttempts(t *testing.T) {
69+
attempts := 0
70+
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
71+
attempts++
72+
reset := fmt.Sprintf("%d", time.Now().Add(1*time.Second).Unix())
73+
rw.Header().Set("X-RateLimit-Reset", reset)
74+
rw.WriteHeader(http.StatusTooManyRequests)
75+
_, _ = rw.Write([]byte("rate limited"))
76+
}))
77+
defer server.Close()
78+
79+
rwrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, _ strfmt.Registry) error {
80+
return nil
81+
})
82+
83+
hu, _ := url.Parse(server.URL)
84+
rt := NewRetryableTransport(httptransport.New(hu.Host, "/", []string{"http"}), 2)
85+
86+
_, err := rt.Submit(&runtime.ClientOperation{
87+
ID: "getSites",
88+
Method: "GET",
89+
PathPattern: "/",
90+
Params: rwrtr,
91+
Reader: runtime.ClientResponseReaderFunc(func(response runtime.ClientResponse, consumer runtime.Consumer) (interface{}, error) {
92+
if response.Code() == 200 {
93+
var result string
94+
if err := consumer.Consume(response.Body(), &result); err != nil {
95+
return nil, err
96+
}
97+
return result, nil
98+
}
99+
return nil, errors.New("Generic error")
100+
}),
101+
})
102+
103+
require.Error(t, err)
104+
require.Equal(t, 2, attempts)
105+
}
106+
107+
func TestRetryableWithDifferentError(t *testing.T) {
108+
attempts := 0
109+
110+
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
111+
attempts++
112+
113+
rw.WriteHeader(http.StatusNotFound)
114+
_, _ = rw.Write([]byte("not found"))
115+
}))
116+
defer server.Close()
117+
118+
rwrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, _ strfmt.Registry) error {
119+
return nil
120+
})
121+
122+
hu, _ := url.Parse(server.URL)
123+
rt := NewRetryableTransport(httptransport.New(hu.Host, "/", []string{"http"}), 2)
124+
125+
_, err := rt.Submit(&runtime.ClientOperation{
126+
ID: "getSites",
127+
Method: "GET",
128+
PathPattern: "/",
129+
Params: rwrtr,
130+
Reader: runtime.ClientResponseReaderFunc(func(response runtime.ClientResponse, consumer runtime.Consumer) (interface{}, error) {
131+
if response.Code() == 200 {
132+
var result string
133+
if err := consumer.Consume(response.Body(), &result); err != nil {
134+
return nil, err
135+
}
136+
return result, nil
137+
}
138+
return nil, errors.New("Generic error")
139+
}),
140+
})
141+
142+
require.Error(t, err)
143+
require.Equal(t, 1, attempts)
144+
}

0 commit comments

Comments
 (0)