@@ -11,12 +11,9 @@ import (
11
11
"context"
12
12
"crypto/rand"
13
13
"encoding/base64"
14
- "encoding/json"
15
14
"errors"
16
15
"fmt"
17
- "io/ioutil"
18
16
"net/http"
19
- "os"
20
17
"strings"
21
18
"time"
22
19
@@ -36,35 +33,23 @@ const (
36
33
)
37
34
38
35
type awsConversation struct {
39
- state clientState
40
- valid bool
41
- nonce []byte
42
- username string
43
- password string
44
- token string
45
- httpClient * http.Client
36
+ state clientState
37
+ valid bool
38
+ nonce []byte
39
+ provider interface {
40
+ getCredentials (ctx context.Context ) (* awsv4.StaticProvider , error )
41
+ }
46
42
}
47
43
48
44
type serverMessage struct {
49
45
Nonce primitive.Binary `bson:"s"`
50
46
Host string `bson:"h"`
51
47
}
52
48
53
- type ecsResponse struct {
54
- AccessKeyID string `json:"AccessKeyId"`
55
- SecretAccessKey string `json:"SecretAccessKey"`
56
- Token string `json:"Token"`
57
- }
58
-
59
49
const (
60
50
amzDateFormat = "20060102T150405Z"
61
- awsRelativeURI = "http://169.254.170.2/"
62
- awsEC2URI = "http://169.254.169.254/"
63
- awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
64
- awsEC2TokenPath = "latest/api/token"
65
51
defaultRegion = "us-east-1"
66
52
maxHostLength = 255
67
- defaultHTTPTimeout = 10 * time .Second
68
53
responceNonceLength = 64
69
54
)
70
55
@@ -128,149 +113,6 @@ func getRegion(host string) (string, error) {
128
113
return region , nil
129
114
}
130
115
131
- func (ac * awsConversation ) validateAndMakeCredentials () (* awsv4.StaticProvider , error ) {
132
- if ac .username != "" && ac .password == "" {
133
- return nil , errors .New ("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing" )
134
- }
135
- if ac .username == "" && ac .password != "" {
136
- return nil , errors .New ("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing" )
137
- }
138
- if ac .username == "" && ac .password == "" && ac .token != "" {
139
- return nil , errors .New ("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing" )
140
- }
141
- if ac .username != "" || ac .password != "" || ac .token != "" {
142
- return & awsv4.StaticProvider {Value : awsv4.Value {
143
- AccessKeyID : ac .username ,
144
- SecretAccessKey : ac .password ,
145
- SessionToken : ac .token ,
146
- }}, nil
147
- }
148
- return nil , nil
149
- }
150
-
151
- func executeAWSHTTPRequest (httpClient * http.Client , req * http.Request ) ([]byte , error ) {
152
- ctx , cancel := context .WithTimeout (context .Background (), defaultHTTPTimeout )
153
- defer cancel ()
154
- resp , err := httpClient .Do (req .WithContext (ctx ))
155
- if err != nil {
156
- return nil , err
157
- }
158
- defer resp .Body .Close ()
159
-
160
- return ioutil .ReadAll (resp .Body )
161
- }
162
-
163
- func (ac * awsConversation ) getEC2Credentials () (* awsv4.StaticProvider , error ) {
164
- // get token
165
- req , err := http .NewRequest ("PUT" , awsEC2URI + awsEC2TokenPath , nil )
166
- if err != nil {
167
- return nil , err
168
- }
169
- req .Header .Set ("X-aws-ec2-metadata-token-ttl-seconds" , "30" )
170
-
171
- token , err := executeAWSHTTPRequest (ac .httpClient , req )
172
- if err != nil {
173
- return nil , err
174
- }
175
- if len (token ) == 0 {
176
- return nil , errors .New ("unable to retrieve token from EC2 metadata" )
177
- }
178
- tokenStr := string (token )
179
-
180
- // get role name
181
- req , err = http .NewRequest ("GET" , awsEC2URI + awsEC2RolePath , nil )
182
- if err != nil {
183
- return nil , err
184
- }
185
- req .Header .Set ("X-aws-ec2-metadata-token" , tokenStr )
186
-
187
- role , err := executeAWSHTTPRequest (ac .httpClient , req )
188
- if err != nil {
189
- return nil , err
190
- }
191
- if len (role ) == 0 {
192
- return nil , errors .New ("unable to retrieve role_name from EC2 metadata" )
193
- }
194
-
195
- // get credentials
196
- pathWithRole := awsEC2URI + awsEC2RolePath + string (role )
197
- req , err = http .NewRequest ("GET" , pathWithRole , nil )
198
- if err != nil {
199
- return nil , err
200
- }
201
- req .Header .Set ("X-aws-ec2-metadata-token" , tokenStr )
202
- creds , err := executeAWSHTTPRequest (ac .httpClient , req )
203
- if err != nil {
204
- return nil , err
205
- }
206
-
207
- var es2Resp ecsResponse
208
- err = json .Unmarshal (creds , & es2Resp )
209
- if err != nil {
210
- return nil , err
211
- }
212
- ac .username = es2Resp .AccessKeyID
213
- ac .password = es2Resp .SecretAccessKey
214
- ac .token = es2Resp .Token
215
-
216
- return ac .validateAndMakeCredentials ()
217
- }
218
-
219
- func (ac * awsConversation ) getCredentials () (* awsv4.StaticProvider , error ) {
220
- // Credentials passed through URI
221
- creds , err := ac .validateAndMakeCredentials ()
222
- if creds != nil || err != nil {
223
- return creds , err
224
- }
225
-
226
- // Credentials from environment variables
227
- ac .username = os .Getenv ("AWS_ACCESS_KEY_ID" )
228
- ac .password = os .Getenv ("AWS_SECRET_ACCESS_KEY" )
229
- ac .token = os .Getenv ("AWS_SESSION_TOKEN" )
230
-
231
- creds , err = ac .validateAndMakeCredentials ()
232
- if creds != nil || err != nil {
233
- return creds , err
234
- }
235
-
236
- // Credentials from ECS metadata
237
- relativeEcsURI := os .Getenv ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" )
238
- if len (relativeEcsURI ) > 0 {
239
- fullURI := awsRelativeURI + relativeEcsURI
240
-
241
- req , err := http .NewRequest ("GET" , fullURI , nil )
242
- if err != nil {
243
- return nil , err
244
- }
245
-
246
- body , err := executeAWSHTTPRequest (ac .httpClient , req )
247
- if err != nil {
248
- return nil , err
249
- }
250
-
251
- var espResp ecsResponse
252
- err = json .Unmarshal (body , & espResp )
253
- if err != nil {
254
- return nil , err
255
- }
256
- ac .username = espResp .AccessKeyID
257
- ac .password = espResp .SecretAccessKey
258
- ac .token = espResp .Token
259
-
260
- creds , err = ac .validateAndMakeCredentials ()
261
- if creds != nil || err != nil {
262
- return creds , err
263
- }
264
- }
265
-
266
- // Credentials from EC2 metadata
267
- creds , err = ac .getEC2Credentials ()
268
- if creds == nil && err == nil {
269
- return nil , errors .New ("unable to get credentials" )
270
- }
271
- return creds , err
272
- }
273
-
274
116
func (ac * awsConversation ) firstMsg () []byte {
275
117
// Values are cached for use in final message parameters
276
118
ac .nonce = make ([]byte , 32 )
@@ -306,7 +148,7 @@ func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
306
148
return nil , err
307
149
}
308
150
309
- creds , err := ac .getCredentials ()
151
+ creds , err := ac .provider . getCredentials (context . Background () )
310
152
if err != nil {
311
153
return nil , err
312
154
}
@@ -320,8 +162,8 @@ func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
320
162
req .Header .Set ("Content-Length" , "43" )
321
163
req .Host = sm .Host
322
164
req .Header .Set ("X-Amz-Date" , currentTime .Format (amzDateFormat ))
323
- if len (ac . token ) > 0 {
324
- req .Header .Set ("X-Amz-Security-Token" , ac . token )
165
+ if len (creds . Value . SessionToken ) > 0 {
166
+ req .Header .Set ("X-Amz-Security-Token" , creds . Value . SessionToken )
325
167
}
326
168
req .Header .Set ("X-MongoDB-Server-Nonce" , base64 .StdEncoding .EncodeToString (sm .Nonce .Data ))
327
169
req .Header .Set ("X-MongoDB-GS2-CB-Flag" , "n" )
@@ -339,8 +181,8 @@ func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
339
181
idx , msg := bsoncore .AppendDocumentStart (nil )
340
182
msg = bsoncore .AppendStringElement (msg , "a" , req .Header .Get ("Authorization" ))
341
183
msg = bsoncore .AppendStringElement (msg , "d" , req .Header .Get ("X-Amz-Date" ))
342
- if len (ac . token ) > 0 {
343
- msg = bsoncore .AppendStringElement (msg , "t" , ac . token )
184
+ if len (creds . Value . SessionToken ) > 0 {
185
+ msg = bsoncore .AppendStringElement (msg , "t" , creds . Value . SessionToken )
344
186
}
345
187
msg , _ = bsoncore .AppendDocumentEnd (msg , idx )
346
188
0 commit comments