Skip to content

Commit 7e732f1

Browse files
authored
aws: Fix BuildableHTTPClient datarace bug (#504)
Fixes a broken unit test missed when implementing #487.
1 parent 0f1fe1b commit 7e732f1

File tree

3 files changed

+77
-57
lines changed

3 files changed

+77
-57
lines changed

aws/http_client.go

Lines changed: 35 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import (
66
"reflect"
77
"sync"
88
"time"
9-
10-
"golang.org/x/net/http2"
119
)
1210

1311
// Defaults for the HTTPTransportBuilder.
@@ -43,7 +41,7 @@ type BuildableHTTPClient struct {
4341
transport *http.Transport
4442
dialer *net.Dialer
4543

46-
initOnce *sync.Once
44+
initOnce sync.Once
4745

4846
clientTimeout time.Duration
4947
client *http.Client
@@ -52,9 +50,7 @@ type BuildableHTTPClient struct {
5250
// NewBuildableHTTPClient returns an initialized client for invoking HTTP
5351
// requests.
5452
func NewBuildableHTTPClient() *BuildableHTTPClient {
55-
return &BuildableHTTPClient{
56-
initOnce: new(sync.Once),
57-
}
53+
return &BuildableHTTPClient{}
5854
}
5955

6056
// Do implements the HTTPClient interface's Do method to invoke a HTTP request,
@@ -68,40 +64,25 @@ func NewBuildableHTTPClient() *BuildableHTTPClient {
6864
// Redirect (3xx) responses will not be followed, the HTTP response received
6965
// will returned instead.
7066
func (b *BuildableHTTPClient) Do(req *http.Request) (*http.Response, error) {
71-
b.initOnce.Do(b.initClient)
67+
b.initOnce.Do(b.build)
7268

7369
return b.client.Do(req)
7470
}
7571

76-
func (b *BuildableHTTPClient) initClient() {
77-
b.client = b.build()
78-
}
79-
80-
// BuildHTTPClient returns an initialized HTTPClient built from the options of
81-
// the builder.
82-
func (b BuildableHTTPClient) build() *http.Client {
83-
var tr *http.Transport
84-
if b.transport != nil {
85-
tr = shallowCopyStruct(b.transport).(*http.Transport)
86-
} else {
87-
tr = defaultHTTPTransport()
88-
}
89-
90-
// TODO Any way to ensure HTTP 2 is supported without depending on
91-
// an unversioned experimental package?
92-
// Maybe only clients that depend on HTTP/2 should call this?
93-
http2.ConfigureTransport(tr)
94-
95-
return wrapWithoutRedirect(&http.Client{
72+
func (b *BuildableHTTPClient) build() {
73+
b.client = wrapWithoutRedirect(&http.Client{
9674
Timeout: b.clientTimeout,
97-
Transport: tr,
75+
Transport: b.GetTransport(),
9876
})
9977
}
10078

101-
func (b BuildableHTTPClient) initReset() BuildableHTTPClient {
102-
b.initOnce = new(sync.Once)
103-
b.client = nil
104-
return b
79+
func (b *BuildableHTTPClient) clone() *BuildableHTTPClient {
80+
cpy := NewBuildableHTTPClient()
81+
cpy.transport = b.GetTransport()
82+
cpy.dialer = b.GetDialer()
83+
cpy.clientTimeout = b.clientTimeout
84+
85+
return cpy
10586
}
10687

10788
// WithTransportOptions copies the BuildableHTTPClient and returns it with the
@@ -110,51 +91,49 @@ func (b BuildableHTTPClient) initReset() BuildableHTTPClient {
11091
// If a non (*http.Transport) was set as the round tripper, the round tripper
11192
// will be replaced with a default Transport value before invoking the option
11293
// functions.
113-
func (b BuildableHTTPClient) WithTransportOptions(opts ...func(*http.Transport)) HTTPClient {
114-
b = b.initReset()
94+
func (b *BuildableHTTPClient) WithTransportOptions(opts ...func(*http.Transport)) HTTPClient {
95+
cpy := b.clone()
11596

116-
tr := b.GetTransport()
97+
tr := cpy.GetTransport()
11798
for _, opt := range opts {
11899
opt(tr)
119100
}
120-
b.transport = tr
101+
cpy.transport = tr
121102

122-
return &b
103+
return cpy
123104
}
124105

125106
// WithDialerOptions copies the BuildableHTTPClient and returns it with the
126107
// net.Dialer options applied. Will set the client's http.Transport DialContext
127108
// member.
128-
func (b BuildableHTTPClient) WithDialerOptions(opts ...func(*net.Dialer)) HTTPClient {
129-
b = b.initReset()
109+
func (b *BuildableHTTPClient) WithDialerOptions(opts ...func(*net.Dialer)) HTTPClient {
110+
cpy := b.clone()
130111

131-
dialer := b.GetDialer()
112+
dialer := cpy.GetDialer()
132113
for _, opt := range opts {
133114
opt(dialer)
134115
}
135-
b.dialer = dialer
116+
cpy.dialer = dialer
136117

137-
tr := b.GetTransport()
138-
tr.DialContext = b.dialer.DialContext
139-
b.transport = tr
118+
tr := cpy.GetTransport()
119+
tr.DialContext = cpy.dialer.DialContext
120+
cpy.transport = tr
140121

141-
return &b
122+
return cpy
142123
}
143124

144125
// WithTimeout Sets the timeout used by the client for all requests.
145-
func (b BuildableHTTPClient) WithTimeout(timeout time.Duration) HTTPClient {
146-
b = b.initReset()
147-
148-
b.clientTimeout = timeout
149-
150-
return &b
126+
func (b *BuildableHTTPClient) WithTimeout(timeout time.Duration) HTTPClient {
127+
cpy := b.clone()
128+
cpy.clientTimeout = timeout
129+
return cpy
151130
}
152131

153132
// GetTransport returns a copy of the client's HTTP Transport.
154-
func (b BuildableHTTPClient) GetTransport() *http.Transport {
133+
func (b *BuildableHTTPClient) GetTransport() *http.Transport {
155134
var tr *http.Transport
156135
if b.transport != nil {
157-
tr = shallowCopyStruct(b.transport).(*http.Transport)
136+
tr = b.transport.Clone()
158137
} else {
159138
tr = defaultHTTPTransport()
160139
}
@@ -163,7 +142,7 @@ func (b BuildableHTTPClient) GetTransport() *http.Transport {
163142
}
164143

165144
// GetDialer returns a copy of the client's network dialer.
166-
func (b BuildableHTTPClient) GetDialer() *net.Dialer {
145+
func (b *BuildableHTTPClient) GetDialer() *net.Dialer {
167146
var dialer *net.Dialer
168147
if b.dialer != nil {
169148
dialer = shallowCopyStruct(b.dialer).(*net.Dialer)
@@ -175,7 +154,7 @@ func (b BuildableHTTPClient) GetDialer() *net.Dialer {
175154
}
176155

177156
// GetTimeout returns a copy of the client's timeout to cancel requests with.
178-
func (b BuildableHTTPClient) GetTimeout() time.Duration {
157+
func (b *BuildableHTTPClient) GetTimeout() time.Duration {
179158
return b.clientTimeout
180159
}
181160

@@ -198,6 +177,7 @@ func defaultHTTPTransport() *http.Transport {
198177
MaxIdleConnsPerHost: DefaultHTTPTransportMaxIdleConnsPerHost,
199178
IdleConnTimeout: DefaultHTTPTransportIdleConnTimeout,
200179
ExpectContinueTimeout: DefaultHTTPTransportExpectContinueTimeout,
180+
ForceAttemptHTTP2: true,
201181
}
202182

203183
return tr

aws/http_client_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package aws_test
33
import (
44
"net/http"
55
"net/http/httptest"
6+
"sync"
67
"testing"
78
"time"
89

@@ -43,3 +44,43 @@ func TestBuildableHTTPClient_WithTimeout(t *testing.T) {
4344
t.Errorf("expect %v timeout, got %v", e, a)
4445
}
4546
}
47+
48+
func TestBuildableHTTPClient_concurrent(t *testing.T) {
49+
server := httptest.NewServer(http.HandlerFunc(
50+
func(w http.ResponseWriter, r *http.Request) {
51+
w.WriteHeader(200)
52+
}))
53+
defer server.Close()
54+
55+
var client aws.HTTPClient = aws.NewBuildableHTTPClient()
56+
57+
atOnce := 100
58+
var wg sync.WaitGroup
59+
wg.Add(atOnce)
60+
for i := 0; i < atOnce; i++ {
61+
go func(i int, client aws.HTTPClient) {
62+
defer wg.Done()
63+
64+
if v, ok := client.(interface{ GetTimeout() time.Duration }); ok {
65+
v.GetTimeout()
66+
}
67+
68+
if i%3 == 0 {
69+
if v, ok := client.(interface {
70+
WithTransportOptions(opts ...func(*http.Transport)) aws.HTTPClient
71+
}); ok {
72+
client = v.WithTransportOptions()
73+
}
74+
}
75+
76+
req, _ := http.NewRequest("GET", server.URL, nil)
77+
resp, err := client.Do(req)
78+
if err != nil {
79+
t.Errorf("expect no error, got %v", err)
80+
}
81+
resp.Body.Close()
82+
}(i, client)
83+
}
84+
85+
wg.Wait()
86+
}

service/s3/s3manager/upload_internal_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,9 @@ func TestUploadByteSlicePool_Failures(t *testing.T) {
174174
}
175175

176176
if r.Operation.Name == operation {
177-
r.Retryable = aws.Bool(false)
178177
r.Error = fmt.Errorf("request error")
179178
r.HTTPResponse = &http.Response{
180-
StatusCode: 500,
179+
StatusCode: 400,
181180
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
182181
}
183182
return

0 commit comments

Comments
 (0)