@@ -13,7 +13,9 @@ import (
1313 "bytes"
1414 "context"
1515 "crypto/tls"
16+ "crypto/x509"
1617 "encoding/base64"
18+ "encoding/json"
1719 "fmt"
1820 "io/ioutil"
1921 "net"
@@ -30,6 +32,7 @@ import (
3032 "go.mongodb.org/mongo-driver/v2/internal/handshake"
3133 "go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
3234 "go.mongodb.org/mongo-driver/v2/internal/integtest"
35+ "go.mongodb.org/mongo-driver/v2/internal/require"
3336 "go.mongodb.org/mongo-driver/v2/mongo"
3437 "go.mongodb.org/mongo-driver/v2/mongo/options"
3538 "go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
@@ -2918,7 +2921,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
29182921 }
29192922 })
29202923
2921- mt .RunOpts ("22 . range explicit encryption applies defaults" , qeRunOpts22 , func (mt * mtest.T ) {
2924+ mt .RunOpts ("23 . range explicit encryption applies defaults" , qeRunOpts22 , func (mt * mtest.T ) {
29222925 err := mt .Client .Database ("keyvault" ).Collection ("datakeys" ).Drop (context .Background ())
29232926 assert .Nil (mt , err , "error on Drop: %v" , err )
29242927
@@ -2979,6 +2982,67 @@ func TestClientSideEncryptionProse(t *testing.T) {
29792982 assert .Greater (t , len (payload .Data ), len (payloadDefaults .Data ), "the returned payload size is expected to be greater than %d" , len (payloadDefaults .Data ))
29802983 })
29812984 })
2985+
2986+ mt .RunOpts ("24. KMS Retry Tests" , qeRunOpts22 , func (mt * mtest.T ) {
2987+ setFailPoint := func (failure string , count int ) error {
2988+ url := fmt .Sprintf ("https://localhost:9003/set_failpoint/%s" , failure )
2989+ var payloadBuf bytes.Buffer
2990+ body := map [string ]int {"count" : count }
2991+ json .NewEncoder (& payloadBuf ).Encode (body )
2992+ req , err := http .NewRequest (http .MethodPost , url , & payloadBuf )
2993+ if err != nil {
2994+ return err
2995+ }
2996+
2997+ cert , err := ioutil .ReadFile (os .Getenv ("CSFLE_TLS_CA_FILE" ))
2998+ if err != nil {
2999+ return err
3000+ }
3001+
3002+ certPool := x509 .NewCertPool ()
3003+ certPool .AppendCertsFromPEM (cert )
3004+
3005+ client := & http.Client {
3006+ Transport : & http.Transport {
3007+ TLSClientConfig : & tls.Config {
3008+ RootCAs : certPool ,
3009+ },
3010+ },
3011+ }
3012+ _ , err = client .Do (req )
3013+ return err
3014+ }
3015+
3016+ keyVaultClient , err := mongo .Connect (options .Client ().ApplyURI (mtest .ClusterURI ()))
3017+ require .NoError (mt , err , "error on Connect: %v" , err )
3018+
3019+ ceo := options .ClientEncryption ().
3020+ SetKeyVaultNamespace ("keyvault.datakeys" ).
3021+ SetKmsProviders (fullKmsProvidersMap )
3022+ clientEncryption , err := mongo .NewClientEncryption (keyVaultClient , ceo )
3023+ require .NoError (mt , err , "error on NewClientEncryption: %v" , err )
3024+
3025+ err = setFailPoint ("http" , 1 )
3026+ require .NoError (mt , err , "mock server error: %v" , err )
3027+
3028+ dkOpts := options .DataKey ().SetMasterKey (
3029+ bson.D {
3030+ {"region" , "foo" },
3031+ {"key" , "bar" },
3032+ {"endpoint" , "127.0.0.1:9003" },
3033+ },
3034+ )
3035+ var keyID bson.Binary
3036+ keyID , err = clientEncryption .CreateDataKey (context .Background (), "aws" , dkOpts )
3037+ require .NoError (mt , err , "error in CreateDataKey: %v" , err )
3038+
3039+ testVal := bson.RawValue {Type : bson .TypeInt32 , Value : bsoncore .AppendInt32 (nil , 123 )}
3040+ eo := options .Encrypt ().
3041+ SetKeyID (keyID ).
3042+ SetAlgorithm ("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" )
3043+ _ , err = clientEncryption .Encrypt (context .Background (), testVal , eo )
3044+ assert .NoError (mt , err , "error in Encrypt: %v" , err )
3045+ })
29823046}
29833047
29843048func getWatcher (mt * mtest.T , streamType mongo.StreamType , cpt * cseProseTest ) watcher {
0 commit comments