Skip to content

Commit dcba04b

Browse files
authored
Add SigV4 middleware for aws-sdk-go-v2 (#257)
* Add new sigv4 middleware for aws-sdk-go-v2 * Improve header handling & support pdc
1 parent 292bbb3 commit dcba04b

File tree

6 files changed

+282
-27
lines changed

6 files changed

+282
-27
lines changed

pkg/awsauth/auth.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@ package awsauth
33
import (
44
"context"
55
"fmt"
6-
"github.com/grafana/grafana-aws-sdk/pkg/awsds"
7-
86
"github.com/aws/aws-sdk-go-v2/aws"
7+
"github.com/grafana/grafana-aws-sdk/pkg/awsds"
98
"github.com/grafana/grafana-plugin-sdk-go/backend"
109
"strings"
1110
)

pkg/awsauth/settings.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
smithymiddleware "github.com/aws/smithy-go/middleware"
1919

2020
"github.com/grafana/grafana-aws-sdk/pkg/awsds"
21+
"github.com/grafana/grafana-aws-sdk/pkg/common"
2122
"github.com/grafana/grafana-plugin-sdk-go/backend/proxy"
2223
"github.com/grafana/grafana-plugin-sdk-go/build"
2324
)
@@ -144,6 +145,9 @@ func (s Settings) WithGrafanaAssumeRole(ctx context.Context, client AWSAPIClient
144145
}
145146

146147
func (s Settings) WithAssumeRole(cfg aws.Config, client AWSAPIClient) LoadOptionsFunc {
148+
if common.IsOptInRegion(cfg.Region) {
149+
cfg.Region = "us-east-1"
150+
}
147151
stsClient := client.NewSTSClientFromConfig(cfg)
148152
provider := client.NewAssumeRoleProvider(stsClient, s.AssumeRoleARN, func(options *stscreds.AssumeRoleOptions) {
149153
if s.ExternalID != "" {

pkg/awsauth/sigv4.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package awsauth
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"encoding/hex"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/aws/aws-sdk-go-v2/aws"
12+
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
13+
14+
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
15+
)
16+
17+
func NewSigV4Middleware(signerOpts ...func(signer *v4.SignerOptions)) httpclient.Middleware {
18+
return SignerMiddleware{signerOpts}
19+
}
20+
21+
type SignerMiddleware struct {
22+
signerOpts []func(*v4.SignerOptions)
23+
}
24+
25+
func (s SignerMiddleware) CreateMiddleware(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
26+
if opts.SigV4 == nil {
27+
return next
28+
}
29+
return NewSignerRoundTripper(opts, next, v4.NewSigner(s.signerOpts...))
30+
}
31+
32+
func (s SignerMiddleware) MiddlewareName() string {
33+
return "sigv4"
34+
}
35+
36+
func NewSignerRoundTripper(opts httpclient.Options, next http.RoundTripper, signer v4.HTTPSigner) SignerRoundTripper {
37+
return SignerRoundTripper{
38+
httpOptions: opts,
39+
next: next,
40+
awsConfigProvider: NewConfigProvider(),
41+
signer: signer,
42+
clock: systemClock{},
43+
}
44+
}
45+
46+
type SignerRoundTripper struct {
47+
httpOptions httpclient.Options
48+
next http.RoundTripper
49+
awsConfigProvider ConfigProvider
50+
signer v4.HTTPSigner
51+
clock Clock
52+
}
53+
54+
func (s SignerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
55+
awsAuthSettings := Settings{
56+
AuthType: AuthType(s.httpOptions.SigV4.AuthType),
57+
AccessKey: s.httpOptions.SigV4.AccessKey,
58+
SecretKey: s.httpOptions.SigV4.SecretKey,
59+
Region: s.httpOptions.SigV4.Region,
60+
CredentialsProfile: s.httpOptions.SigV4.Profile,
61+
AssumeRoleARN: s.httpOptions.SigV4.AssumeRoleARN,
62+
ExternalID: s.httpOptions.SigV4.ExternalID,
63+
ProxyOptions: s.httpOptions.ProxyOptions,
64+
}
65+
ctx := req.Context()
66+
cfg, err := s.awsConfigProvider.GetConfig(ctx, awsAuthSettings)
67+
if err != nil {
68+
return nil, err
69+
}
70+
credentials, err := cfg.Credentials.Retrieve(ctx)
71+
if err != nil {
72+
return nil, err
73+
}
74+
err = s.SignHTTP(ctx, req, credentials)
75+
if err != nil {
76+
return nil, err
77+
}
78+
return s.next.RoundTrip(req)
79+
}
80+
81+
func (s SignerRoundTripper) SignHTTP(ctx context.Context, req *http.Request, credentials aws.Credentials) error {
82+
// we start req with empty headers since that's what the signer is expecting,
83+
// but add them back at the end
84+
headers := req.Header
85+
req.Header = make(http.Header)
86+
defer func() {
87+
// replace the custom headers before returning
88+
for k, v := range headers {
89+
req.Header[k] = v
90+
}
91+
}()
92+
payloadHash, err := getRequestBodyHash(req)
93+
if err != nil {
94+
return err
95+
}
96+
return s.signer.SignHTTP(ctx, credentials, req, payloadHash, s.httpOptions.SigV4.Service, s.httpOptions.SigV4.Region, s.clock.Now().UTC())
97+
}
98+
99+
func getRequestBodyHash(req *http.Request) (string, error) {
100+
body, err := req.GetBody()
101+
if err != nil {
102+
return "", err
103+
}
104+
hash := sha256.New()
105+
_, err = io.Copy(hash, body)
106+
if err != nil {
107+
return "", err
108+
}
109+
return hex.EncodeToString(hash.Sum(nil)), nil
110+
111+
}
112+
113+
type Clock interface {
114+
Now() time.Time
115+
}
116+
type systemClock struct{}
117+
118+
func (systemClock) Now() time.Time { return time.Now() }

pkg/awsauth/sigv4_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package awsauth
2+
3+
import (
4+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
5+
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
"net/http"
9+
"strings"
10+
"testing"
11+
"time"
12+
)
13+
14+
const EmptySha256Hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
15+
16+
var OnceUponATime = time.Unix(1234567890, 0) // 2009-02-13 UTC
17+
var AtALaterTime = time.Unix(1234567891, 0) // 2009-02-13 UTC
18+
19+
func TestSignerRoundTripper_SignHTTP(t *testing.T) {
20+
tests := []struct {
21+
name string
22+
sigV4Config *httpclient.SigV4Config
23+
requestBody string
24+
customHeaders http.Header
25+
differentTimes bool
26+
}{
27+
{
28+
name: "basic success",
29+
sigV4Config: &httpclient.SigV4Config{
30+
AuthType: "keys",
31+
AccessKey: "good",
32+
SecretKey: "excellent",
33+
Region: "us-east-1",
34+
},
35+
},
36+
{
37+
name: "with custom headers",
38+
sigV4Config: &httpclient.SigV4Config{
39+
AuthType: "keys",
40+
AccessKey: "good",
41+
SecretKey: "excellent",
42+
Region: "us-east-1",
43+
},
44+
customHeaders: http.Header{"X-Testing-Stuff": []string{"is good"}},
45+
},
46+
{
47+
name: "signature changes with different time",
48+
sigV4Config: &httpclient.SigV4Config{
49+
AuthType: "keys",
50+
AccessKey: "good",
51+
SecretKey: "excellent",
52+
Region: "us-east-1",
53+
},
54+
differentTimes: true,
55+
},
56+
}
57+
for _, tt := range tests {
58+
t.Run(tt.name, func(t *testing.T) {
59+
next := &testRoundTripper{}
60+
s := NewSignerRoundTripper(httpclient.Options{SigV4: tt.sigV4Config}, next, v4.NewSigner())
61+
s.awsConfigProvider = NewFakeConfigProvider(false)
62+
s.clock = staticClock{OnceUponATime}
63+
64+
req, _ := http.NewRequest("GET", "https://service.aws.amazon.notreally", strings.NewReader(tt.requestBody))
65+
_, err := s.RoundTrip(req)
66+
require.NoError(t, err)
67+
require.NotEmpty(t, req.Header["Authorization"])
68+
69+
if tt.customHeaders != nil {
70+
reqWithHeaders, _ := http.NewRequest("GET", "https://service.aws.amazon.notreally", strings.NewReader(tt.requestBody))
71+
reqWithHeaders.Header = tt.customHeaders
72+
_, err = s.RoundTrip(reqWithHeaders)
73+
require.NoError(t, err)
74+
75+
// custom headers should not affect the signature
76+
require.Equal(t, req.Header["Authorization"], reqWithHeaders.Header["Authorization"])
77+
// ... but should be retained
78+
for k, v := range tt.customHeaders {
79+
require.Equal(t, v, reqWithHeaders.Header[k])
80+
}
81+
}
82+
if tt.differentTimes {
83+
s.clock = staticClock{AtALaterTime}
84+
reqLater, _ := http.NewRequest("GET", "https://service.aws.amazon.notreally", strings.NewReader(tt.requestBody))
85+
_, err = s.RoundTrip(reqLater)
86+
require.NoError(t, err)
87+
require.NotEqual(t, req.Header["Authorization"], reqLater.Header["Authorization"])
88+
89+
}
90+
})
91+
}
92+
}
93+
func Test_getRequestBodyHash(t *testing.T) {
94+
tests := []struct {
95+
name string
96+
body string
97+
expected string
98+
}{
99+
{
100+
name: "empty body is empty hash",
101+
body: "",
102+
expected: EmptySha256Hash,
103+
},
104+
{
105+
name: "hello world",
106+
body: "hello world",
107+
expected: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9",
108+
},
109+
}
110+
for _, tt := range tests {
111+
t.Run(tt.name, func(t *testing.T) {
112+
req, _ := http.NewRequest("get", "https://whatever.wherever:999", strings.NewReader(tt.body))
113+
got, _ := getRequestBodyHash(req)
114+
assert.Equalf(t, tt.expected, got, "getRequestBodyHash(%v)", req)
115+
})
116+
}
117+
}
118+
119+
type staticClock struct {
120+
when time.Time
121+
}
122+
123+
func (s staticClock) Now() time.Time { return s.when }
124+
125+
type testRoundTripper struct {
126+
seen *http.Request
127+
}
128+
129+
func (t *testRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
130+
t.seen = request
131+
return &http.Response{Status: "everything is awesome", StatusCode: 200}, nil
132+
}

pkg/awsds/sessions.go

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"sync"
1111
"time"
1212

13+
"github.com/grafana/grafana-aws-sdk/pkg/common"
14+
1315
"github.com/grafana/grafana-plugin-sdk-go/backend"
1416
"github.com/grafana/grafana-plugin-sdk-go/experimental/errorsource"
1517

@@ -101,30 +103,6 @@ type SessionConfig struct {
101103
AuthSettings *AuthSettings
102104
}
103105

104-
func isOptInRegion(region string) bool {
105-
// Opt-in region from https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-available-regions
106-
regions := map[string]bool{
107-
"af-south-1": true,
108-
"ap-east-1": true,
109-
"ap-east-2": true,
110-
"ap-south-2": true,
111-
"ap-southeast-3": true,
112-
"ap-southeast-4": true,
113-
"ap-southeast-5": true,
114-
"ap-southeast-7": true,
115-
"ca-west-1": true,
116-
"eu-central-2": true,
117-
"eu-south-1": true,
118-
"eu-south-2": true,
119-
"il-central-1": true,
120-
"me-central-1": true,
121-
"me-south-1": true,
122-
"mx-central-1": true,
123-
// The rest of regions will return false
124-
}
125-
return regions[region]
126-
}
127-
128106
// Deprecated: use GetSessionWithAuthSettings instead
129107
func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
130108
if c.Settings.Region == "" && c.Settings.DefaultRegion != "" {
@@ -193,7 +171,7 @@ func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
193171
c.Settings.Region = ""
194172
}
195173
if c.Settings.Region != "" {
196-
if c.Settings.AssumeRoleARN != "" && c.AuthSettings.AssumeRoleEnabled && isOptInRegion(c.Settings.Region) {
174+
if c.Settings.AssumeRoleARN != "" && c.AuthSettings.AssumeRoleEnabled && common.IsOptInRegion(c.Settings.Region) {
197175
// When assuming a role, the real region is set later in a new session
198176
// so we use a well-known region here (not opt-in) to obtain valid credentials
199177
regionCfg = &aws.Config{Region: aws.String("us-east-1")}

pkg/common/opt_in.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package common
2+
3+
func IsOptInRegion(region string) bool {
4+
// Opt-in regions listed at https://docs.aws.amazon.com/global-infrastructure/latest/regions/aws-regions.html
5+
regions := map[string]bool{
6+
"af-south-1": true,
7+
"ap-east-1": true,
8+
"ap-east-2": true,
9+
"ap-south-2": true,
10+
"ap-southeast-3": true,
11+
"ap-southeast-4": true,
12+
"ap-southeast-5": true,
13+
"ap-southeast-7": true,
14+
"ca-west-1": true,
15+
"eu-central-2": true,
16+
"eu-south-1": true,
17+
"eu-south-2": true,
18+
"il-central-1": true,
19+
"me-central-1": true,
20+
"me-south-1": true,
21+
"mx-central-1": true,
22+
}
23+
return regions[region]
24+
}

0 commit comments

Comments
 (0)