Skip to content

Commit 0cbab0b

Browse files
authored
GODRIVER-2410 Obtain AWS credentials for CSFLE in the same way as for MONGODB-AWS. (#1143)
1 parent 13ffe8d commit 0cbab0b

File tree

10 files changed

+468
-247
lines changed

10 files changed

+468
-247
lines changed

.evergreen/config.yml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2028,6 +2028,50 @@ tasks:
20282028
EXPECT_ERROR='unable to retrieve GCP credentials' \
20292029
./testgcpkms
20302030
2031+
- name: "testawskms-task"
2032+
commands:
2033+
- command: shell.exec
2034+
type: test
2035+
params:
2036+
working_dir: src/go.mongodb.org/mongo-driver
2037+
shell: "bash"
2038+
script: |
2039+
${PREPARE_SHELL}
2040+
echo "Building build-awskms-test ... begin"
2041+
BUILD_TAGS="-tags cse" \
2042+
PKG_CONFIG_PATH=$PKG_CONFIG_PATH \
2043+
make build-awskms-test
2044+
echo "Building build-awskms-test ... end"
2045+
2046+
export AWS_ACCESS_KEY_ID="${cse_aws_access_key_id}"
2047+
export AWS_SECRET_ACCESS_KEY="${cse_aws_secret_access_key}"
2048+
2049+
LD_LIBRARY_PATH=./install/libmongocrypt/lib \
2050+
MONGODB_URI='${atlas_free_tier_uri}' \
2051+
./testawskms
2052+
2053+
- name: "testawskms-fail-task"
2054+
# testawskms-fail-task runs without environment variables.
2055+
# It is expected to fail to obtain credentials.
2056+
commands:
2057+
- command: shell.exec
2058+
type: test
2059+
params:
2060+
working_dir: src/go.mongodb.org/mongo-driver
2061+
shell: "bash"
2062+
script: |
2063+
${PREPARE_SHELL}
2064+
echo "Building build-awskms-test ... begin"
2065+
BUILD_TAGS="-tags cse" \
2066+
PKG_CONFIG_PATH=$PKG_CONFIG_PATH \
2067+
make build-awskms-test
2068+
echo "Building build-awskms-test ... end"
2069+
2070+
LD_LIBRARY_PATH=./install/libmongocrypt/lib \
2071+
MONGODB_URI='${atlas_free_tier_uri}' \
2072+
EXPECT_ERROR='unable to retrieve AWS credentials' \
2073+
./testawskms
2074+
20312075
- name: "test-fuzz"
20322076
commands:
20332077
- func: bootstrap-mongo-orchestration
@@ -2444,3 +2488,13 @@ buildvariants:
24442488
- name: testgcpkms_task_group
24452489
batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README
24462490
- testgcpkms-fail-task
2491+
2492+
- name: testawskms-variant
2493+
display_name: "AWS KMS"
2494+
run_on:
2495+
- debian11-small
2496+
expansions:
2497+
GO_DIST: "/opt/golang/go1.18"
2498+
tasks:
2499+
- testawskms-task
2500+
- testawskms-fail-task

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ evg-test-versioned-api:
174174
build-gcpkms-test:
175175
go build $(BUILD_TAGS) ./cmd/testgcpkms
176176

177+
.PHONY: build-awskms-test
178+
build-awskms-test:
179+
go build $(BUILD_TAGS) ./cmd/testawskms
180+
177181
### Benchmark specific targets and support. ###
178182
.PHONY: benchmark
179183
benchmark:perf

cmd/testawskms/main.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (C) MongoDB, Inc. 2022-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package main
8+
9+
import (
10+
"context"
11+
"fmt"
12+
"os"
13+
"strings"
14+
15+
"go.mongodb.org/mongo-driver/bson"
16+
"go.mongodb.org/mongo-driver/mongo"
17+
"go.mongodb.org/mongo-driver/mongo/options"
18+
)
19+
20+
func main() {
21+
uri := os.Getenv("MONGODB_URI")
22+
// expecterror is an expect error substring. Set to empty string to expect no error.
23+
expecterror := os.Getenv("EXPECT_ERROR")
24+
25+
if uri == "" {
26+
fmt.Println("ERROR: Please set required MONGODB_URI environment variable.")
27+
fmt.Println("The following environment variables are understood:")
28+
fmt.Println("- MONGODB_URI as a MongoDB URI. Example: 'mongodb://localhost:27017'")
29+
fmt.Println("- EXPECT_ERROR as an optional expected error substring.")
30+
os.Exit(1)
31+
}
32+
33+
cOpts := options.Client().ApplyURI(uri)
34+
keyVaultClient, err := mongo.Connect(context.Background(), cOpts)
35+
if err != nil {
36+
panic(fmt.Sprintf("Connect error: %v", err))
37+
}
38+
defer keyVaultClient.Disconnect(context.Background())
39+
40+
kmsProvidersMap := map[string]map[string]interface{}{
41+
"aws": {},
42+
}
43+
ceOpts := options.ClientEncryption().SetKmsProviders(kmsProvidersMap).SetKeyVaultNamespace("keyvault.datakeys")
44+
ce, err := mongo.NewClientEncryption(keyVaultClient, ceOpts)
45+
if err != nil {
46+
panic(fmt.Sprintf("Error in NewClientEncryption: %v", err))
47+
}
48+
defer ce.Close(context.Background())
49+
50+
dkOpts := options.DataKey().SetMasterKey(bson.M{
51+
"region": "us-east-1",
52+
"key": "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
53+
})
54+
_, err = ce.CreateDataKey(context.Background(), "aws", dkOpts)
55+
if expecterror == "" {
56+
if err != nil {
57+
panic(fmt.Sprintf("Expected success, but got error in CreateDataKey: %v", err))
58+
}
59+
} else {
60+
if err == nil {
61+
panic(fmt.Sprintf("Expected error message to contain %q, but got no error", expecterror))
62+
}
63+
if !strings.Contains(err.Error(), expecterror) {
64+
panic(fmt.Sprintf("Expected error message to contain %q, but got %q", expecterror, err.Error()))
65+
}
66+
}
67+
}

x/mongo/driver/auth/aws_conv.go

Lines changed: 11 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@ import (
1111
"context"
1212
"crypto/rand"
1313
"encoding/base64"
14-
"encoding/json"
1514
"errors"
1615
"fmt"
17-
"io/ioutil"
1816
"net/http"
19-
"os"
2017
"strings"
2118
"time"
2219

@@ -36,35 +33,23 @@ const (
3633
)
3734

3835
type awsConversation struct {
39-
state clientState
40-
valid bool
41-
nonce []byte
42-
username string
43-
password string
44-
token string
45-
httpClient *http.Client
36+
state clientState
37+
valid bool
38+
nonce []byte
39+
provider interface {
40+
getCredentials(ctx context.Context) (*awsv4.StaticProvider, error)
41+
}
4642
}
4743

4844
type serverMessage struct {
4945
Nonce primitive.Binary `bson:"s"`
5046
Host string `bson:"h"`
5147
}
5248

53-
type ecsResponse struct {
54-
AccessKeyID string `json:"AccessKeyId"`
55-
SecretAccessKey string `json:"SecretAccessKey"`
56-
Token string `json:"Token"`
57-
}
58-
5949
const (
6050
amzDateFormat = "20060102T150405Z"
61-
awsRelativeURI = "http://169.254.170.2/"
62-
awsEC2URI = "http://169.254.169.254/"
63-
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
64-
awsEC2TokenPath = "latest/api/token"
6551
defaultRegion = "us-east-1"
6652
maxHostLength = 255
67-
defaultHTTPTimeout = 10 * time.Second
6853
responceNonceLength = 64
6954
)
7055

@@ -128,149 +113,6 @@ func getRegion(host string) (string, error) {
128113
return region, nil
129114
}
130115

131-
func (ac *awsConversation) validateAndMakeCredentials() (*awsv4.StaticProvider, error) {
132-
if ac.username != "" && ac.password == "" {
133-
return nil, errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
134-
}
135-
if ac.username == "" && ac.password != "" {
136-
return nil, errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
137-
}
138-
if ac.username == "" && ac.password == "" && ac.token != "" {
139-
return nil, errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
140-
}
141-
if ac.username != "" || ac.password != "" || ac.token != "" {
142-
return &awsv4.StaticProvider{Value: awsv4.Value{
143-
AccessKeyID: ac.username,
144-
SecretAccessKey: ac.password,
145-
SessionToken: ac.token,
146-
}}, nil
147-
}
148-
return nil, nil
149-
}
150-
151-
func executeAWSHTTPRequest(httpClient *http.Client, req *http.Request) ([]byte, error) {
152-
ctx, cancel := context.WithTimeout(context.Background(), defaultHTTPTimeout)
153-
defer cancel()
154-
resp, err := httpClient.Do(req.WithContext(ctx))
155-
if err != nil {
156-
return nil, err
157-
}
158-
defer resp.Body.Close()
159-
160-
return ioutil.ReadAll(resp.Body)
161-
}
162-
163-
func (ac *awsConversation) getEC2Credentials() (*awsv4.StaticProvider, error) {
164-
// get token
165-
req, err := http.NewRequest("PUT", awsEC2URI+awsEC2TokenPath, nil)
166-
if err != nil {
167-
return nil, err
168-
}
169-
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "30")
170-
171-
token, err := executeAWSHTTPRequest(ac.httpClient, req)
172-
if err != nil {
173-
return nil, err
174-
}
175-
if len(token) == 0 {
176-
return nil, errors.New("unable to retrieve token from EC2 metadata")
177-
}
178-
tokenStr := string(token)
179-
180-
// get role name
181-
req, err = http.NewRequest("GET", awsEC2URI+awsEC2RolePath, nil)
182-
if err != nil {
183-
return nil, err
184-
}
185-
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
186-
187-
role, err := executeAWSHTTPRequest(ac.httpClient, req)
188-
if err != nil {
189-
return nil, err
190-
}
191-
if len(role) == 0 {
192-
return nil, errors.New("unable to retrieve role_name from EC2 metadata")
193-
}
194-
195-
// get credentials
196-
pathWithRole := awsEC2URI + awsEC2RolePath + string(role)
197-
req, err = http.NewRequest("GET", pathWithRole, nil)
198-
if err != nil {
199-
return nil, err
200-
}
201-
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
202-
creds, err := executeAWSHTTPRequest(ac.httpClient, req)
203-
if err != nil {
204-
return nil, err
205-
}
206-
207-
var es2Resp ecsResponse
208-
err = json.Unmarshal(creds, &es2Resp)
209-
if err != nil {
210-
return nil, err
211-
}
212-
ac.username = es2Resp.AccessKeyID
213-
ac.password = es2Resp.SecretAccessKey
214-
ac.token = es2Resp.Token
215-
216-
return ac.validateAndMakeCredentials()
217-
}
218-
219-
func (ac *awsConversation) getCredentials() (*awsv4.StaticProvider, error) {
220-
// Credentials passed through URI
221-
creds, err := ac.validateAndMakeCredentials()
222-
if creds != nil || err != nil {
223-
return creds, err
224-
}
225-
226-
// Credentials from environment variables
227-
ac.username = os.Getenv("AWS_ACCESS_KEY_ID")
228-
ac.password = os.Getenv("AWS_SECRET_ACCESS_KEY")
229-
ac.token = os.Getenv("AWS_SESSION_TOKEN")
230-
231-
creds, err = ac.validateAndMakeCredentials()
232-
if creds != nil || err != nil {
233-
return creds, err
234-
}
235-
236-
// Credentials from ECS metadata
237-
relativeEcsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
238-
if len(relativeEcsURI) > 0 {
239-
fullURI := awsRelativeURI + relativeEcsURI
240-
241-
req, err := http.NewRequest("GET", fullURI, nil)
242-
if err != nil {
243-
return nil, err
244-
}
245-
246-
body, err := executeAWSHTTPRequest(ac.httpClient, req)
247-
if err != nil {
248-
return nil, err
249-
}
250-
251-
var espResp ecsResponse
252-
err = json.Unmarshal(body, &espResp)
253-
if err != nil {
254-
return nil, err
255-
}
256-
ac.username = espResp.AccessKeyID
257-
ac.password = espResp.SecretAccessKey
258-
ac.token = espResp.Token
259-
260-
creds, err = ac.validateAndMakeCredentials()
261-
if creds != nil || err != nil {
262-
return creds, err
263-
}
264-
}
265-
266-
// Credentials from EC2 metadata
267-
creds, err = ac.getEC2Credentials()
268-
if creds == nil && err == nil {
269-
return nil, errors.New("unable to get credentials")
270-
}
271-
return creds, err
272-
}
273-
274116
func (ac *awsConversation) firstMsg() []byte {
275117
// Values are cached for use in final message parameters
276118
ac.nonce = make([]byte, 32)
@@ -306,7 +148,7 @@ func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
306148
return nil, err
307149
}
308150

309-
creds, err := ac.getCredentials()
151+
creds, err := ac.provider.getCredentials(context.Background())
310152
if err != nil {
311153
return nil, err
312154
}
@@ -320,8 +162,8 @@ func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
320162
req.Header.Set("Content-Length", "43")
321163
req.Host = sm.Host
322164
req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
323-
if len(ac.token) > 0 {
324-
req.Header.Set("X-Amz-Security-Token", ac.token)
165+
if len(creds.Value.SessionToken) > 0 {
166+
req.Header.Set("X-Amz-Security-Token", creds.Value.SessionToken)
325167
}
326168
req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
327169
req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
@@ -339,8 +181,8 @@ func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
339181
idx, msg := bsoncore.AppendDocumentStart(nil)
340182
msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
341183
msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
342-
if len(ac.token) > 0 {
343-
msg = bsoncore.AppendStringElement(msg, "t", ac.token)
184+
if len(creds.Value.SessionToken) > 0 {
185+
msg = bsoncore.AppendStringElement(msg, "t", creds.Value.SessionToken)
344186
}
345187
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
346188

0 commit comments

Comments
 (0)