Skip to content

Commit 0d3bd7a

Browse files
authored
feature/ec2/imds: Fix Client's response handling and operation timeout race (#1448)
Fixes #1253 race between reading a IMDS response, and the operationTimeout middleware cleaning up its timeout context. Changes the IMDS client to always buffer the response body received, before the result is deserialized. This ensures that the consumer of the operation's response body will not race with context cleanup within the middleware stack. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If an Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout. Updates operationTimeout so that if DefaultTimeout is unset (aka zero) operationTimeout will not set a default timeout on the context.
1 parent f1baf2d commit 0d3bd7a

File tree

5 files changed

+196
-10
lines changed

5 files changed

+196
-10
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "17ac8941-9cd9-4e59-8fcc-93444709cc1a",
3+
"type": "feature",
4+
"description": "Respect passed in Context Deadline/Timeout. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If an Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout.",
5+
"modules": [
6+
"feature/ec2/imds"
7+
]
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c",
3+
"type": "bugfix",
4+
"description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253",
5+
"modules": [
6+
"feature/ec2/imds"
7+
]
8+
}

feature/ec2/imds/doc.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// Package imds provides the API client for interacting with the Amazon EC2
22
// Instance Metadata Service.
33
//
4+
// All Client operation calls have a default timeout. If the operation is not
5+
// completed before this timeout expires, the operation will be canceled. This
6+
// timeout can be overridden by providing Context with a timeout or deadline
7+
// with calling the client's operations.
8+
//
49
// See the EC2 IMDS user guide for more information on using the API.
510
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html
611
package imds

feature/ec2/imds/request_middleware.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package imds
22

33
import (
4+
"bytes"
45
"context"
56
"fmt"
7+
"io/ioutil"
68
"net/url"
79
"path"
810
"time"
@@ -52,7 +54,7 @@ func addRequestMiddleware(stack *middleware.Stack,
5254

5355
// Operation timeout
5456
err = stack.Initialize.Add(&operationTimeout{
55-
Timeout: defaultOperationTimeout,
57+
DefaultTimeout: defaultOperationTimeout,
5658
}, middleware.Before)
5759
if err != nil {
5860
return err
@@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize(
142144
resp, ok := out.RawResponse.(*smithyhttp.Response)
143145
if !ok {
144146
return out, metadata, fmt.Errorf(
145-
"unexpected transport response type, %T", out.RawResponse)
147+
"unexpected transport response type, %T, want %T", out.RawResponse, resp)
146148
}
149+
defer resp.Body.Close()
147150

148-
// Anything thats not 200 |< 300 is error
151+
// read the full body so that any operation timeouts cleanup will not race
152+
// the body being read.
153+
body, err := ioutil.ReadAll(resp.Body)
154+
if err != nil {
155+
return out, metadata, fmt.Errorf("read response body failed, %w", err)
156+
}
157+
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
158+
159+
// Anything that's not 200 |< 300 is error
149160
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
150-
resp.Body.Close()
151161
return out, metadata, &smithyhttp.ResponseError{
152162
Response: resp,
153163
Err: fmt.Errorf("request to EC2 IMDS failed"),
@@ -213,8 +223,19 @@ const (
213223
defaultOperationTimeout = 5 * time.Second
214224
)
215225

226+
// operationTimeout adds a timeout on the middleware stack if the Context the
227+
// stack was called with does not have a deadline. The next middleware must
228+
// complete before the timeout, or the context will be canceled.
229+
//
230+
// If DefaultTimeout is zero, no default timeout will be used if the Context
231+
// does not have a timeout.
232+
//
233+
// The next middleware must also ensure that any resources that are also
234+
// canceled by the stack's context are completely consumed before returning.
235+
// Otherwise the timeout cleanup will race the resource being consumed
236+
// upstream.
216237
type operationTimeout struct {
217-
Timeout time.Duration
238+
DefaultTimeout time.Duration
218239
}
219240

220241
func (*operationTimeout) ID() string { return "OperationTimeout" }
@@ -224,10 +245,11 @@ func (m *operationTimeout) HandleInitialize(
224245
) (
225246
output middleware.InitializeOutput, metadata middleware.Metadata, err error,
226247
) {
227-
var cancelFn func()
228-
229-
ctx, cancelFn = context.WithTimeout(ctx, m.Timeout)
230-
defer cancelFn()
248+
if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
249+
var cancelFn func()
250+
ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
251+
defer cancelFn()
252+
}
231253

232254
return next.HandleInitialize(ctx, input)
233255
}

feature/ec2/imds/request_middleware_test.go

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/hex"
7+
"fmt"
78
"io"
89
"io/ioutil"
910
"net/http"
@@ -126,7 +127,7 @@ func TestAddRequestMiddleware(t *testing.T) {
126127

127128
func TestOperationTimeoutMiddleware(t *testing.T) {
128129
m := &operationTimeout{
129-
Timeout: time.Nanosecond,
130+
DefaultTimeout: time.Nanosecond,
130131
}
131132

132133
_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
@@ -135,6 +136,10 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
135136
) (
136137
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
137138
) {
139+
if _, ok := ctx.Deadline(); !ok {
140+
return out, metadata, fmt.Errorf("expect context deadline to be set")
141+
}
142+
138143
if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
139144
return out, metadata, err
140145
}
@@ -150,6 +155,144 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
150155
}
151156
}
152157

158+
func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) {
159+
m := &operationTimeout{}
160+
161+
_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
162+
middleware.InitializeHandlerFunc(func(
163+
ctx context.Context, input middleware.InitializeInput,
164+
) (
165+
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
166+
) {
167+
if t, ok := ctx.Deadline(); ok {
168+
return out, metadata, fmt.Errorf("expect no context deadline, got %v", t)
169+
}
170+
171+
return out, metadata, nil
172+
}))
173+
if err != nil {
174+
t.Fatalf("expect no error, got %v", err)
175+
}
176+
}
177+
178+
func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) {
179+
m := &operationTimeout{
180+
DefaultTimeout: time.Nanosecond,
181+
}
182+
183+
expectDeadline := time.Now().Add(time.Hour)
184+
ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline)
185+
defer cancelFn()
186+
187+
_, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{},
188+
middleware.InitializeHandlerFunc(func(
189+
ctx context.Context, input middleware.InitializeInput,
190+
) (
191+
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
192+
) {
193+
t, ok := ctx.Deadline()
194+
if !ok {
195+
return out, metadata, fmt.Errorf("expect context deadline to be set")
196+
}
197+
if e, a := expectDeadline, t; !e.Equal(a) {
198+
return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a)
199+
}
200+
201+
return out, metadata, nil
202+
}))
203+
if err != nil {
204+
t.Fatalf("expect no error, got %v", err)
205+
}
206+
}
207+
208+
// Ensure that the response body is read in the deserialize middleware,
209+
// ensuring that the timeoutOperation middleware won't race canceling the
210+
// context with the upstream reading the response body.
211+
// * https://github.com/aws/aws-sdk-go-v2/issues/1253
212+
func TestDeserailizeResponse_cacheBody(t *testing.T) {
213+
type Output struct {
214+
Content io.ReadCloser
215+
}
216+
m := &deserializeResponse{
217+
GetOutput: func(resp *smithyhttp.Response) (interface{}, error) {
218+
return &Output{
219+
Content: resp.Body,
220+
}, nil
221+
},
222+
}
223+
224+
expectBody := "hello world!"
225+
originalBody := &bytesReader{
226+
reader: strings.NewReader(expectBody),
227+
}
228+
if originalBody.closed {
229+
t.Fatalf("expect original body not to be closed yet")
230+
}
231+
232+
out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{},
233+
middleware.DeserializeHandlerFunc(func(
234+
ctx context.Context, input middleware.DeserializeInput,
235+
) (
236+
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
237+
) {
238+
out.RawResponse = &smithyhttp.Response{
239+
Response: &http.Response{
240+
StatusCode: 200,
241+
Status: "200 OK",
242+
Header: http.Header{},
243+
ContentLength: int64(originalBody.Len()),
244+
Body: originalBody,
245+
},
246+
}
247+
return out, metadata, nil
248+
}))
249+
if err != nil {
250+
t.Fatalf("expect no error, got %v", err)
251+
}
252+
253+
if !originalBody.closed {
254+
t.Errorf("expect original body to be closed, was not")
255+
}
256+
257+
result, ok := out.Result.(*Output)
258+
if !ok {
259+
t.Fatalf("expect result to be Output, got %T, %v", result, result)
260+
}
261+
262+
actualBody, err := ioutil.ReadAll(result.Content)
263+
if err != nil {
264+
t.Fatalf("expect no error, got %v", err)
265+
}
266+
if e, a := expectBody, string(actualBody); e != a {
267+
t.Errorf("expect %v body, got %v", e, a)
268+
}
269+
if err := result.Content.Close(); err != nil {
270+
t.Fatalf("expect no error, got %v", err)
271+
}
272+
}
273+
274+
type bytesReader struct {
275+
reader interface {
276+
io.Reader
277+
Len() int
278+
}
279+
closed bool
280+
}
281+
282+
func (r *bytesReader) Len() int {
283+
return r.reader.Len()
284+
}
285+
func (r *bytesReader) Close() error {
286+
r.closed = true
287+
return nil
288+
}
289+
func (r *bytesReader) Read(p []byte) (int, error) {
290+
if r.closed {
291+
return 0, io.EOF
292+
}
293+
return r.reader.Read(p)
294+
}
295+
153296
type successAPIResponseHandler struct {
154297
t *testing.T
155298
path string

0 commit comments

Comments
 (0)