|
| 1 | +package awsauth |
| 2 | + |
| 3 | +import ( |
| 4 | + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" |
| 5 | + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
| 6 | + "github.com/stretchr/testify/assert" |
| 7 | + "github.com/stretchr/testify/require" |
| 8 | + "net/http" |
| 9 | + "strings" |
| 10 | + "testing" |
| 11 | + "time" |
| 12 | +) |
| 13 | + |
| 14 | +const EmptySha256Hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" |
| 15 | + |
| 16 | +var OnceUponATime = time.Unix(1234567890, 0) // 2009-02-13 UTC |
| 17 | +var AtALaterTime = time.Unix(1234567891, 0) // 2009-02-13 UTC |
| 18 | + |
| 19 | +func TestSignerRoundTripper_SignHTTP(t *testing.T) { |
| 20 | + tests := []struct { |
| 21 | + name string |
| 22 | + sigV4Config *httpclient.SigV4Config |
| 23 | + requestBody string |
| 24 | + customHeaders http.Header |
| 25 | + differentTimes bool |
| 26 | + }{ |
| 27 | + { |
| 28 | + name: "basic success", |
| 29 | + sigV4Config: &httpclient.SigV4Config{ |
| 30 | + AuthType: "keys", |
| 31 | + AccessKey: "good", |
| 32 | + SecretKey: "excellent", |
| 33 | + Region: "us-east-1", |
| 34 | + }, |
| 35 | + }, |
| 36 | + { |
| 37 | + name: "with custom headers", |
| 38 | + sigV4Config: &httpclient.SigV4Config{ |
| 39 | + AuthType: "keys", |
| 40 | + AccessKey: "good", |
| 41 | + SecretKey: "excellent", |
| 42 | + Region: "us-east-1", |
| 43 | + }, |
| 44 | + customHeaders: http.Header{"X-Testing-Stuff": []string{"is good"}}, |
| 45 | + }, |
| 46 | + { |
| 47 | + name: "signature changes with different time", |
| 48 | + sigV4Config: &httpclient.SigV4Config{ |
| 49 | + AuthType: "keys", |
| 50 | + AccessKey: "good", |
| 51 | + SecretKey: "excellent", |
| 52 | + Region: "us-east-1", |
| 53 | + }, |
| 54 | + differentTimes: true, |
| 55 | + }, |
| 56 | + } |
| 57 | + for _, tt := range tests { |
| 58 | + t.Run(tt.name, func(t *testing.T) { |
| 59 | + next := &testRoundTripper{} |
| 60 | + s := NewSignerRoundTripper(httpclient.Options{SigV4: tt.sigV4Config}, next, v4.NewSigner()) |
| 61 | + s.awsConfigProvider = NewFakeConfigProvider(false) |
| 62 | + s.clock = staticClock{OnceUponATime} |
| 63 | + |
| 64 | + req, _ := http.NewRequest("GET", "https://service.aws.amazon.notreally", strings.NewReader(tt.requestBody)) |
| 65 | + _, err := s.RoundTrip(req) |
| 66 | + require.NoError(t, err) |
| 67 | + require.NotEmpty(t, req.Header["Authorization"]) |
| 68 | + |
| 69 | + if tt.customHeaders != nil { |
| 70 | + reqWithHeaders, _ := http.NewRequest("GET", "https://service.aws.amazon.notreally", strings.NewReader(tt.requestBody)) |
| 71 | + reqWithHeaders.Header = tt.customHeaders |
| 72 | + _, err = s.RoundTrip(reqWithHeaders) |
| 73 | + require.NoError(t, err) |
| 74 | + |
| 75 | + // custom headers should not affect the signature |
| 76 | + require.Equal(t, req.Header["Authorization"], reqWithHeaders.Header["Authorization"]) |
| 77 | + // ... but should be retained |
| 78 | + for k, v := range tt.customHeaders { |
| 79 | + require.Equal(t, v, reqWithHeaders.Header[k]) |
| 80 | + } |
| 81 | + } |
| 82 | + if tt.differentTimes { |
| 83 | + s.clock = staticClock{AtALaterTime} |
| 84 | + reqLater, _ := http.NewRequest("GET", "https://service.aws.amazon.notreally", strings.NewReader(tt.requestBody)) |
| 85 | + _, err = s.RoundTrip(reqLater) |
| 86 | + require.NoError(t, err) |
| 87 | + require.NotEqual(t, req.Header["Authorization"], reqLater.Header["Authorization"]) |
| 88 | + |
| 89 | + } |
| 90 | + }) |
| 91 | + } |
| 92 | +} |
| 93 | +func Test_getRequestBodyHash(t *testing.T) { |
| 94 | + tests := []struct { |
| 95 | + name string |
| 96 | + body string |
| 97 | + expected string |
| 98 | + }{ |
| 99 | + { |
| 100 | + name: "empty body is empty hash", |
| 101 | + body: "", |
| 102 | + expected: EmptySha256Hash, |
| 103 | + }, |
| 104 | + { |
| 105 | + name: "hello world", |
| 106 | + body: "hello world", |
| 107 | + expected: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", |
| 108 | + }, |
| 109 | + } |
| 110 | + for _, tt := range tests { |
| 111 | + t.Run(tt.name, func(t *testing.T) { |
| 112 | + req, _ := http.NewRequest("get", "https://whatever.wherever:999", strings.NewReader(tt.body)) |
| 113 | + got, _ := getRequestBodyHash(req) |
| 114 | + assert.Equalf(t, tt.expected, got, "getRequestBodyHash(%v)", req) |
| 115 | + }) |
| 116 | + } |
| 117 | +} |
| 118 | + |
| 119 | +type staticClock struct { |
| 120 | + when time.Time |
| 121 | +} |
| 122 | + |
| 123 | +func (s staticClock) Now() time.Time { return s.when } |
| 124 | + |
| 125 | +type testRoundTripper struct { |
| 126 | + seen *http.Request |
| 127 | +} |
| 128 | + |
| 129 | +func (t *testRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { |
| 130 | + t.seen = request |
| 131 | + return &http.Response{Status: "everything is awesome", StatusCode: 200}, nil |
| 132 | +} |
0 commit comments