@@ -14,6 +14,7 @@ import (
14
14
"context"
15
15
"crypto/tls"
16
16
"encoding/base64"
17
+ "encoding/json"
17
18
"fmt"
18
19
"io/ioutil"
19
20
"net"
@@ -30,6 +31,7 @@ import (
30
31
"go.mongodb.org/mongo-driver/v2/internal/handshake"
31
32
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
32
33
"go.mongodb.org/mongo-driver/v2/internal/integtest"
34
+ "go.mongodb.org/mongo-driver/v2/internal/require"
33
35
"go.mongodb.org/mongo-driver/v2/mongo"
34
36
"go.mongodb.org/mongo-driver/v2/mongo/options"
35
37
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
@@ -2925,7 +2927,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
2925
2927
}
2926
2928
})
2927
2929
2928
- mt .RunOpts ("22 . range explicit encryption applies defaults" , qeRunOpts22 , func (mt * mtest.T ) {
2930
+ mt .RunOpts ("23 . range explicit encryption applies defaults" , qeRunOpts22 , func (mt * mtest.T ) {
2929
2931
err := mt .Client .Database ("keyvault" ).Collection ("datakeys" ).Drop (context .Background ())
2930
2932
assert .Nil (mt , err , "error on Drop: %v" , err )
2931
2933
@@ -2986,6 +2988,147 @@ func TestClientSideEncryptionProse(t *testing.T) {
2986
2988
assert .Greater (t , len (payload .Data ), len (payloadDefaults .Data ), "the returned payload size is expected to be greater than %d" , len (payloadDefaults .Data ))
2987
2989
})
2988
2990
})
2991
+
2992
+ mt .RunOpts ("24. kms retry tests" , noClientOpts , func (mt * mtest.T ) {
2993
+ kmsTlsTestcase := os .Getenv ("KMS_FAILPOINT_SERVER_RUNNING" )
2994
+ if kmsTlsTestcase == "" {
2995
+ mt .Skipf ("Skipping test as KMS_FAILPOINT_SERVER_RUNNING is not set" )
2996
+ }
2997
+
2998
+ mt .Parallel ()
2999
+
3000
+ tlsCAFile := os .Getenv ("KMS_FAILPOINT_CA_FILE" )
3001
+ require .NotEqual (mt , tlsCAFile , "" , "failed to load CA file" )
3002
+
3003
+ clientAndCATlsMap := map [string ]interface {}{
3004
+ "tlsCAFile" : tlsCAFile ,
3005
+ }
3006
+ tlsCfg , err := options .BuildTLSConfig (clientAndCATlsMap )
3007
+ require .NoError (mt , err , "BuildTLSConfig error: %v" , err )
3008
+
3009
+ setFailPoint := func (failure string , count int ) error {
3010
+ url := fmt .Sprintf ("https://localhost:9003/set_failpoint/%s" , failure )
3011
+ var payloadBuf bytes.Buffer
3012
+ body := map [string ]int {"count" : count }
3013
+ json .NewEncoder (& payloadBuf ).Encode (body )
3014
+ req , err := http .NewRequest (http .MethodPost , url , & payloadBuf )
3015
+ if err != nil {
3016
+ return err
3017
+ }
3018
+
3019
+ client := & http.Client {
3020
+ Transport : & http.Transport {TLSClientConfig : tlsCfg },
3021
+ }
3022
+ res , err := client .Do (req )
3023
+ if err != nil {
3024
+ return err
3025
+ }
3026
+ return res .Body .Close ()
3027
+ }
3028
+
3029
+ kmsProviders := map [string ]map [string ]interface {}{
3030
+ "aws" : {
3031
+ "accessKeyId" : awsAccessKeyID ,
3032
+ "secretAccessKey" : awsSecretAccessKey ,
3033
+ },
3034
+ "azure" : {
3035
+ "tenantId" : azureTenantID ,
3036
+ "clientId" : azureClientID ,
3037
+ "clientSecret" : azureClientSecret ,
3038
+ "identityPlatformEndpoint" : "127.0.0.1:9003" ,
3039
+ },
3040
+ "gcp" : {
3041
+ "email" : gcpEmail ,
3042
+ "privateKey" : gcpPrivateKey ,
3043
+ "endpoint" : "127.0.0.1:9003" ,
3044
+ },
3045
+ }
3046
+
3047
+ dataKeys := []struct {
3048
+ provider string
3049
+ masterKey interface {}
3050
+ }{
3051
+ {"aws" , bson.D {
3052
+ {"region" , "foo" },
3053
+ {"key" , "bar" },
3054
+ {"endpoint" , "127.0.0.1:9003" },
3055
+ }},
3056
+ {"azure" , bson.D {
3057
+ {"keyVaultEndpoint" , "127.0.0.1:9003" },
3058
+ {"keyName" , "foo" },
3059
+ }},
3060
+ {"gcp" , bson.D {
3061
+ {"projectId" , "foo" },
3062
+ {"location" , "bar" },
3063
+ {"keyRing" , "baz" },
3064
+ {"keyName" , "qux" },
3065
+ {"endpoint" , "127.0.0.1:9003" },
3066
+ }},
3067
+ }
3068
+
3069
+ testCases := []struct {
3070
+ name string
3071
+ failure string
3072
+ }{
3073
+ {"Case 1: createDataKey and encrypt with TCP retry" , "network" },
3074
+ {"Case 2: createDataKey and encrypt with HTTP retry" , "http" },
3075
+ }
3076
+
3077
+ for _ , tc := range testCases {
3078
+ for _ , dataKey := range dataKeys {
3079
+ mt .Run (fmt .Sprintf ("%s_%s" , tc .name , dataKey .provider ), func (mt * mtest.T ) {
3080
+ keyVaultClient , err := mongo .Connect (options .Client ().ApplyURI (mtest .ClusterURI ()))
3081
+ require .NoError (mt , err , "error on Connect: %v" , err )
3082
+
3083
+ ceo := options .ClientEncryption ().
3084
+ SetKeyVaultNamespace (kvNamespace ).
3085
+ SetKmsProviders (kmsProviders ).
3086
+ SetTLSConfig (map [string ]* tls.Config {dataKey .provider : tlsCfg })
3087
+ clientEncryption , err := mongo .NewClientEncryption (keyVaultClient , ceo )
3088
+ require .NoError (mt , err , "error on NewClientEncryption: %v" , err )
3089
+
3090
+ err = setFailPoint (tc .failure , 1 )
3091
+ require .NoError (mt , err , "mock server error: %v" , err )
3092
+
3093
+ dkOpts := options .DataKey ().SetMasterKey (dataKey .masterKey )
3094
+ var keyID bson.Binary
3095
+ keyID , err = clientEncryption .CreateDataKey (context .Background (), dataKey .provider , dkOpts )
3096
+ require .NoError (mt , err , "error in CreateDataKey: %v" , err )
3097
+
3098
+ err = setFailPoint (tc .failure , 1 )
3099
+ require .NoError (mt , err , "mock server error: %v" , err )
3100
+
3101
+ testVal := bson.RawValue {Type : bson .TypeInt32 , Value : bsoncore .AppendInt32 (nil , 123 )}
3102
+ eo := options .Encrypt ().
3103
+ SetKeyID (keyID ).
3104
+ SetAlgorithm ("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" )
3105
+ _ , err = clientEncryption .Encrypt (context .Background (), testVal , eo )
3106
+ require .NoError (mt , err , "error in Encrypt: %v" , err )
3107
+ })
3108
+ }
3109
+ }
3110
+
3111
+ for _ , dataKey := range dataKeys {
3112
+ mt .Run (fmt .Sprintf ("Case 3: createDataKey fails after too many retries_%s" , dataKey .provider ), func (mt * mtest.T ) {
3113
+ keyVaultClient , err := mongo .Connect (options .Client ().ApplyURI (mtest .ClusterURI ()))
3114
+ require .NoError (mt , err , "error on Connect: %v" , err )
3115
+
3116
+ ceo := options .ClientEncryption ().
3117
+ SetKeyVaultNamespace (kvNamespace ).
3118
+ SetKmsProviders (kmsProviders ).
3119
+ SetTLSConfig (map [string ]* tls.Config {dataKey .provider : tlsCfg })
3120
+ clientEncryption , err := mongo .NewClientEncryption (keyVaultClient , ceo )
3121
+ require .NoError (mt , err , "error on NewClientEncryption: %v" , err )
3122
+
3123
+ err = setFailPoint ("network" , 4 )
3124
+ require .NoError (mt , err , "mock server error: %v" , err )
3125
+
3126
+ dkOpts := options .DataKey ().SetMasterKey (dataKey .masterKey )
3127
+ _ , err = clientEncryption .CreateDataKey (context .Background (), dataKey .provider , dkOpts )
3128
+ require .ErrorContains (mt , err , "KMS request failed after 3 retries due to a network error" )
3129
+ })
3130
+ }
3131
+ })
2989
3132
}
2990
3133
2991
3134
func getWatcher (mt * mtest.T , streamType mongo.StreamType , cpt * cseProseTest ) watcher {
0 commit comments