Skip to content

Commit 1b3b54b

Browse files
authored
Smithy SigV4 Middlwares and Retry Middleware Changes (#517)
1 parent 9a77f74 commit 1b3b54b

File tree

8 files changed

+384
-22
lines changed

8 files changed

+384
-22
lines changed

aws/middleware/middleware.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ func (r RequestInvocationIDMiddleware) HandleBuild(ctx context.Context, in middl
2828

2929
invocationID, err := sdk.UUIDVersion4()
3030
if err != nil {
31-
return out, middleware.NewMetadata(), err
31+
return out, metadata, err
3232
}
3333

3434
switch req := in.Request.(type) {
3535
case *smithyHTTP.Request:
3636
req.Header.Set(invocationIDHeader, invocationID)
3737
default:
38-
return middleware.BuildOutput{}, middleware.NewMetadata(), fmt.Errorf("unknown transport type %T", req)
38+
return out, metadata, fmt.Errorf("unknown transport type %T", req)
3939
}
4040

4141
return next.HandleBuild(ctx, in)
@@ -56,17 +56,18 @@ func (a AttemptClockSkewMiddleware) HandleDeserialize(ctx context.Context, in mi
5656
) {
5757
respMeta := ResponseMetadata{}
5858

59-
deserialize, metadata, err := next.HandleDeserialize(ctx, in)
59+
out, metadata, err = next.HandleDeserialize(ctx, in)
6060
respMeta.ResponseAt = sdk.NowTime()
6161

62-
switch resp := deserialize.RawResponse.(type) {
62+
switch resp := out.RawResponse.(type) {
6363
case *smithyHTTP.Response:
6464
respDateHeader := resp.Header.Get("Date")
6565
if len(respDateHeader) == 0 {
6666
break
6767
}
68-
respMeta.ServerTime, err = http.ParseTime(respDateHeader)
69-
if err != nil {
68+
var parseErr error
69+
respMeta.ServerTime, parseErr = http.ParseTime(respDateHeader)
70+
if parseErr != nil {
7071
// TODO: What should logging of errors look like?
7172
break
7273
}
@@ -76,9 +77,9 @@ func (a AttemptClockSkewMiddleware) HandleDeserialize(ctx context.Context, in mi
7677
respMeta.AttemptSkew = respMeta.ServerTime.Sub(respMeta.ResponseAt)
7778
}
7879

79-
SetResponseMetadata(metadata, respMeta)
80+
setResponseMetadata(&metadata, respMeta)
8081

81-
return deserialize, metadata, err
82+
return out, metadata, err
8283
}
8384

8485
type responseMetadataKey struct{}
@@ -96,7 +97,7 @@ func GetResponseMetadata(metadata middleware.Metadata) (v ResponseMetadata) {
9697
return v
9798
}
9899

99-
// SetResponseMetadata sets the ResponseMetadata on the given context
100-
func SetResponseMetadata(metadata middleware.Metadata, responseMetadata ResponseMetadata) {
100+
// setResponseMetadata sets the ResponseMetadata on the given context
101+
func setResponseMetadata(metadata *middleware.Metadata, responseMetadata ResponseMetadata) {
101102
metadata.Set(responseMetadataKey{}, responseMetadata)
102103
}

aws/middleware/middleware_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func TestRequestInvocationIDMiddleware(t *testing.T) {
3838
t.Errorf("invocation id was not a UUIDv4")
3939
}
4040

41-
return out, nil, err
41+
return out, metadata, err
4242
}))
4343
if err != nil {
4444
t.Errorf("expected no error, got %v", err)

aws/retry/middleware.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515

1616
// RequestCloner is a function that can take an input request type and clone the request
1717
// for use in a subsequent retry attempt
18-
type RequestCloner func(context.Context, interface{}) interface{}
18+
type RequestCloner func(interface{}) interface{}
1919

2020
type retryMetadata struct {
2121
AttemptNum int
@@ -54,22 +54,28 @@ func (r AttemptMiddleware) HandleFinalize(ctx context.Context, in smithymiddle.F
5454

5555
relRetryToken := r.retryer.GetInitialToken()
5656

57-
origReq := r.requestCloner(ctx, in.Request)
58-
origCtx := ctx
59-
6057
for {
6158
attemptNum++
6259

63-
ctx = setRetryMetadata(origCtx, retryMetadata{
60+
attemptInput := in
61+
attemptInput.Request = r.requestCloner(attemptInput.Request)
62+
63+
attemptCtx := setRetryMetadata(ctx, retryMetadata{
6464
AttemptNum: attemptNum,
6565
AttemptTime: sdk.NowTime(),
6666
MaxAttempts: maxAttempts,
6767
AttemptClockSkew: attemptClockSkew,
6868
})
6969

70-
in.Request = r.requestCloner(ctx, origReq)
70+
if attemptNum > 1 {
71+
if rewindable, ok := in.Request.(interface{ RewindStream() error }); ok {
72+
if err := rewindable.RewindStream(); err != nil {
73+
return out, metadata, fmt.Errorf("failed to rewind transport stream for retry, %w", err)
74+
}
75+
}
76+
}
7177

72-
out, metadata, reqErr := next.HandleFinalize(ctx, in)
78+
out, metadata, reqErr := next.HandleFinalize(attemptCtx, attemptInput)
7379

7480
relRetryToken(reqErr)
7581
if reqErr == nil {
@@ -125,7 +131,7 @@ func (r MetricsHeaderMiddleware) HandleFinalize(ctx context.Context, in smithymi
125131
) {
126132
retryMetadata, ok := getRetryMetadata(ctx)
127133
if !ok {
128-
return out, smithymiddle.NewMetadata(), fmt.Errorf("retry metadata value not found on context")
134+
return out, metadata, fmt.Errorf("retry metadata value not found on context")
129135
}
130136

131137
const retryMetricHeader = "amz-sdk-request"
@@ -152,7 +158,7 @@ func (r MetricsHeaderMiddleware) HandleFinalize(ctx context.Context, in smithymi
152158
case *http.Request:
153159
req.Header.Set(retryMetricHeader, strings.Join(parts, "; "))
154160
default:
155-
return smithymiddle.FinalizeOutput{}, smithymiddle.NewMetadata(), fmt.Errorf("unknown transport type %T", req)
161+
return out, metadata, fmt.Errorf("unknown transport type %T", req)
156162
}
157163

158164
return next.HandleFinalize(ctx, in)

aws/retry/middleware_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func TestMetricsHeaderMiddleware(t *testing.T) {
7373
t.Errorf("expected %v, got %v", e, a)
7474
}
7575

76-
return out, nil, err
76+
return out, metadata, err
7777
}))
7878
if err != nil && len(tt.expectedErr) == 0 {
7979
t.Fatalf("expected no error, got %q", err)

aws/signer/internal/v4/const.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ const (
44
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string
55
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
66

7+
// UnsignedPayload indicates that the request payload body is unsigned
8+
UnsignedPayload = "UNSIGNED-PAYLOAD"
9+
710
// AmzAlgorithmKey indicates the signing algorithm
811
AmzAlgorithmKey = "X-Amz-Algorithm"
912

aws/signer/v4/middleware.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package v4
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"encoding/hex"
7+
"fmt"
8+
"io"
9+
10+
v4Internal "github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4"
11+
"github.com/aws/aws-sdk-go-v2/internal/sdk"
12+
"github.com/awslabs/smithy-go/middleware"
13+
smithyHTTP "github.com/awslabs/smithy-go/transport/http"
14+
)
15+
16+
// HashComputationError indicates an error occurred while computing the signing hash
17+
type HashComputationError struct {
18+
Err error
19+
}
20+
21+
// Error is the error message
22+
func (e *HashComputationError) Error() string {
23+
return fmt.Sprintf("failed to compute payload hash: %v", e.Err)
24+
}
25+
26+
// Unwrap returns the underlying error if one is set
27+
func (e *HashComputationError) Unwrap() error {
28+
return e.Err
29+
}
30+
31+
// SigningError indicates an error condition occurred while performing SigV4 signing
32+
type SigningError struct {
33+
Err error
34+
}
35+
36+
func (e *SigningError) Error() string {
37+
return fmt.Sprintf("failed to sign request: %v", e.Err)
38+
}
39+
40+
// Unwrap returns the underlying error cause
41+
func (e *SigningError) Unwrap() error {
42+
return e.Err
43+
}
44+
45+
// UnsignedPayloadMiddleware sets the SigV4 request payload hash to unsigned
46+
type UnsignedPayloadMiddleware struct{}
47+
48+
// ID returns the UnsignedPayloadMiddleware identifier
49+
func (m *UnsignedPayloadMiddleware) ID() string {
50+
return "SigV4UnsignedPayloadMiddleware"
51+
}
52+
53+
// HandleFinalize sets the payload hash to be an unsigned payload
54+
func (m *UnsignedPayloadMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
55+
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
56+
) {
57+
ctx = SetPayloadHash(ctx, v4Internal.UnsignedPayload)
58+
return next.HandleFinalize(ctx, in)
59+
}
60+
61+
// ComputePayloadSHA256Middleware computes sha256 payload hash to sign
62+
type ComputePayloadSHA256Middleware struct{}
63+
64+
// ID is the middleware name
65+
func (m *ComputePayloadSHA256Middleware) ID() string {
66+
return "ComputePayloadSHA256Middleware"
67+
}
68+
69+
// HandleFinalize compute the payload hash for the request payload
70+
func (m *ComputePayloadSHA256Middleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
71+
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
72+
) {
73+
req, ok := in.Request.(*smithyHTTP.Request)
74+
if !ok {
75+
return out, metadata, &HashComputationError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
76+
}
77+
78+
hash := sha256.New()
79+
_, err = io.Copy(hash, req.GetStream())
80+
if err != nil {
81+
return out, metadata, &HashComputationError{Err: fmt.Errorf("failed to compute payload hash, %w", err)}
82+
}
83+
84+
if err := req.RewindStream(); err != nil {
85+
return out, metadata, &HashComputationError{Err: fmt.Errorf("failed to seek body to start, %w", err)}
86+
}
87+
88+
ctx = SetPayloadHash(ctx, hex.EncodeToString(hash.Sum(nil)))
89+
90+
return next.HandleFinalize(ctx, in)
91+
}
92+
93+
// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation for SigV4 HTTP Signing
94+
type SignHTTPRequestMiddleware struct {
95+
signer HTTPSigner
96+
}
97+
98+
// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given Signer for signing requests
99+
func NewSignHTTPRequestMiddleware(signer HTTPSigner) *SignHTTPRequestMiddleware {
100+
return &SignHTTPRequestMiddleware{signer: signer}
101+
}
102+
103+
// ID is the SignHTTPRequestMiddleware identifier
104+
func (s *SignHTTPRequestMiddleware) ID() string {
105+
return "SigV4SignHTTPRequestMiddleware"
106+
}
107+
108+
// HandleFinalize will take the provided input and sign the request using the SigV4 authentication scheme
109+
func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
110+
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
111+
) {
112+
req, ok := in.Request.(*smithyHTTP.Request)
113+
if !ok {
114+
return out, metadata, &SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
115+
}
116+
117+
signingMetadata := GetSigningMetadata(ctx)
118+
payloadHash := GetPayloadHash(ctx)
119+
if len(payloadHash) == 0 {
120+
return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
121+
}
122+
123+
err = s.signer.SignHTTP(ctx, req.Request, payloadHash, signingMetadata.SigningName, signingMetadata.SigningRegion, sdk.NowTime())
124+
if err != nil {
125+
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
126+
}
127+
128+
return next.HandleFinalize(ctx, in)
129+
}
130+
131+
// SigningMetadata contains the signing name and signing region to be used when signing
132+
// with SigV4 authentication scheme.
133+
type SigningMetadata struct {
134+
SigningName string
135+
SigningRegion string
136+
}
137+
138+
type signingMetadataKey struct{}
139+
140+
// GetSigningMetadata retrieves the SigningMetadata from context. If there is no SigningMetadata attached to the context
141+
// an zero-value SigningMetadata will be returned.
142+
func GetSigningMetadata(ctx context.Context) (v SigningMetadata) {
143+
v, _ = ctx.Value(signingMetadataKey{}).(SigningMetadata)
144+
return v
145+
}
146+
147+
// SetSigningMetadata adds the provided metadata to the context
148+
func SetSigningMetadata(ctx context.Context, metadata SigningMetadata) context.Context {
149+
ctx = context.WithValue(ctx, signingMetadataKey{}, metadata)
150+
return ctx
151+
}
152+
153+
type payloadHashKey struct{}
154+
155+
// GetPayloadHash retrieves the payload hash to use for signing
156+
func GetPayloadHash(ctx context.Context) (v string) {
157+
v, _ = ctx.Value(payloadHashKey{}).(string)
158+
return v
159+
}
160+
161+
// SetPayloadHash sets the payload hash to be used for signing the request
162+
func SetPayloadHash(ctx context.Context, hash string) context.Context {
163+
ctx = context.WithValue(ctx, payloadHashKey{}, hash)
164+
return ctx
165+
}

0 commit comments

Comments
 (0)