Skip to content

Commit 9802dda

Browse files
committed
update test case
1 parent a23e594 commit 9802dda

File tree

3 files changed

+44
-36
lines changed

3 files changed

+44
-36
lines changed

internal/integration/client_side_encryption_prose_test.go

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3148,6 +3148,8 @@ func TestClientSideEncryptionProse(t *testing.T) {
31483148
})
31493149

31503150
mt.RunOpts("26. custom AWS credentials", qeRunOpts22, func(mt *mtest.T) {
3151+
provider := credproviders.NewEnvProvider()
3152+
31513153
mt.Run("Case 1: ClientEncryption with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) {
31523154
opts := options.Client().ApplyURI(mtest.ClusterURI())
31533155
integtest.AddTestServerAPIVersion(opts)
@@ -3165,7 +3167,6 @@ func TestClientSideEncryptionProse(t *testing.T) {
31653167
SetCredentialProviders(map[string]options.CredentialsProvider{
31663168
"aws": func(ctx context.Context) (options.Credentials, error) {
31673169
var cred options.Credentials
3168-
provider := credproviders.NewEnvProvider()
31693170
c, err := provider.Retrieve(ctx)
31703171
if err != nil {
31713172
return cred, err
@@ -3177,12 +3178,9 @@ func TestClientSideEncryptionProse(t *testing.T) {
31773178
return cred, nil
31783179
},
31793180
})
3180-
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3181-
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
3182-
3183-
dkOpts := options.DataKey()
3184-
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3185-
assert.Error(mt, err, "expected an error")
3181+
_, err = mongo.NewClientEncryption(keyVaultClient, ceo)
3182+
assert.ErrorContains(mt, err, "can only provide a custom AWS credential provider",
3183+
"unexpected error: %v", err)
31863184
})
31873185
mt.Run("Case 2: ClientEncryption with credentialProviders works", func(mt *mtest.T) {
31883186
opts := options.Client().ApplyURI(mtest.ClusterURI())
@@ -3209,7 +3207,10 @@ func TestClientSideEncryptionProse(t *testing.T) {
32093207
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
32103208
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
32113209

3212-
dkOpts := options.DataKey()
3210+
dkOpts := options.DataKey().SetMasterKey(bson.D{
3211+
{"region", "us-east-1"},
3212+
{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
3213+
})
32133214
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
32143215
assert.NoErrorf(mt, err, "unexpected error %v", err)
32153216
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
@@ -3227,7 +3228,6 @@ func TestClientSideEncryptionProse(t *testing.T) {
32273228
SetCredentialProviders(map[string]options.CredentialsProvider{
32283229
"aws": func(ctx context.Context) (options.Credentials, error) {
32293230
var cred options.Credentials
3230-
provider := credproviders.NewEnvProvider()
32313231
c, err := provider.Retrieve(ctx)
32323232
if err != nil {
32333233
return cred, err
@@ -3242,47 +3242,45 @@ func TestClientSideEncryptionProse(t *testing.T) {
32423242
co := options.Client().SetAutoEncryptionOptions(aeo).ApplyURI(mtest.ClusterURI())
32433243
integtest.AddTestServerAPIVersion(co)
32443244
_, err := mongo.Connect(co)
3245-
assert.Error(mt, err, "expected an error")
3245+
assert.ErrorContainsf(mt, err, "can only provide a custom AWS credential provider",
3246+
"unexpected error: %v", err)
32463247
})
32473248

32483249
mt.Run("Case 4: ClientEncryption with credentialProviders and valid environment variables", func(mt *mtest.T) {
3249-
mt.Setenv("AWS_ACCESS_KEY_ID", os.Getenv("FLE_AWS_SECRET_ACCESS_KEY"))
3250-
mt.Setenv("AWS_SECRET_ACCESS_KEY", os.Getenv("FLE_AWS_ACCESS_KEY_ID"))
3250+
// mt.Setenv("AWS_ACCESS_KEY_ID", os.Getenv("FLE_AWS_SECRET_ACCESS_KEY"))
3251+
// mt.Setenv("AWS_SECRET_ACCESS_KEY", os.Getenv("FLE_AWS_ACCESS_KEY_ID"))
32513252

32523253
opts := options.Client().ApplyURI(mtest.ClusterURI())
32533254
integtest.AddTestServerAPIVersion(opts)
32543255
keyVaultClient, err := mongo.Connect(opts)
32553256
assert.NoErrorf(mt, err, "error on Connect: %v", err)
32563257

3258+
var calledCount int
32573259
ceo := options.ClientEncryption().
32583260
SetKeyVaultNamespace("keyvault.datakeys").
32593261
SetKmsProviders(map[string]map[string]any{
3260-
"aws": {
3261-
"accessKeyId": awsAccessKeyID,
3262-
"secretAccessKey": awsSecretAccessKey,
3263-
},
3262+
"aws": map[string]any{},
32643263
}).
32653264
SetCredentialProviders(map[string]options.CredentialsProvider{
32663265
"aws": func(ctx context.Context) (options.Credentials, error) {
3267-
var cred options.Credentials
3268-
provider := credproviders.NewEnvProvider()
3269-
c, err := provider.Retrieve(ctx)
3270-
if err != nil {
3271-
return cred, err
3272-
}
3273-
cred.AccessKeyID = c.AccessKeyID
3274-
cred.SecretAccessKey = c.SecretAccessKey
3275-
cred.SessionToken = c.SessionToken
3276-
cred.ExpirationCallback = provider.IsExpired
3277-
return cred, nil
3266+
calledCount++
3267+
return options.Credentials{
3268+
AccessKeyID: awsAccessKeyID,
3269+
SecretAccessKey: awsSecretAccessKey,
3270+
ExpirationCallback: func() bool { return false },
3271+
}, nil
32783272
},
32793273
})
32803274
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
32813275
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
32823276

3283-
dkOpts := options.DataKey()
3277+
dkOpts := options.DataKey().SetMasterKey(bson.D{
3278+
{"region", "us-east-1"},
3279+
{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
3280+
})
32843281
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
32853282
assert.NoErrorf(mt, err, "unexpected error %v", err)
3283+
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
32863284
})
32873285
})
32883286
}

internal/test/aws/aws_test.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"go.mongodb.org/mongo-driver/v2/bson"
16+
"go.mongodb.org/mongo-driver/v2/internal/credproviders"
1617
"go.mongodb.org/mongo-driver/v2/internal/require"
1718
"go.mongodb.org/mongo-driver/v2/mongo"
1819
"go.mongodb.org/mongo-driver/v2/mongo/options"
@@ -46,15 +47,20 @@ func TestAWSCustomCredentialProviders(t *testing.T) {
4647
}
4748

4849
var calledCount int
50+
provider := credproviders.NewEnvProvider()
4951
awsCredential := options.Credential{
50-
AuthMechanism: "MONGODB-AWS",
51-
AwsCredentialsProvider: func(_ context.Context) (options.Credentials, error) {
52+
AwsCredentialsProvider: func(ctx context.Context) (options.Credentials, error) {
5253
calledCount++
53-
return options.Credentials{
54-
AccessKeyID: os.Getenv("AWS_ACCESS_KEY_ID"),
55-
SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
56-
ExpirationCallback: func() bool { return false },
57-
}, nil
54+
var creds options.Credentials
55+
value, err := provider.Retrieve(ctx)
56+
if err != nil {
57+
return creds, err
58+
}
59+
creds.AccessKeyID = value.AccessKeyID
60+
creds.SecretAccessKey = value.SecretAccessKey
61+
creds.SessionToken = value.SessionToken
62+
creds.ExpirationCallback = provider.IsExpired
63+
return creds, nil
5864
},
5965
}
6066
client, err := mongo.Connect(options.Client().ApplyURI(uri).SetAuth(awsCredential))

x/mongo/driver/mongocrypt/mongocrypt.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,16 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
6565
if needsKmsProvider(opts.KmsProviders, "gcp") {
6666
kmsProviders["gcp"] = creds.NewGCPCredentialProvider(httpClient)
6767
}
68+
provider, ok := opts.CredentialProviders["aws"]
6869
if needsKmsProvider(opts.KmsProviders, "aws") {
6970
var providers []credentials.Provider
70-
if provider, ok := opts.CredentialProviders["aws"]; ok {
71+
if ok {
7172
providers = append(providers, provider)
7273
}
7374
kmsProviders["aws"] = creds.NewAWSCredentialProvider(httpClient, providers...)
75+
} else if ok {
76+
return nil, fmt.Errorf("can only provide a custom AWS credential provider " +
77+
"when the state machine is configured for automatic AWS credential fetching")
7478
}
7579
if needsKmsProvider(opts.KmsProviders, "azure") {
7680
kmsProviders["azure"] = creds.NewAzureCredentialProvider(httpClient)

0 commit comments

Comments
 (0)