Skip to content

Commit b11c2a5

Browse files
committed
GODRIVER-3168 Retry KMS requests on transient errors.
1 parent cf0348c commit b11c2a5

File tree

1 file changed

+74
-1
lines changed

1 file changed

+74
-1
lines changed

internal/integration/client_side_encryption_prose_test.go

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@ import (
1313
"bytes"
1414
"context"
1515
"crypto/tls"
16+
"crypto/x509"
1617
"encoding/base64"
18+
"encoding/json"
1719
"fmt"
1820
"io/ioutil"
21+
"log"
1922
"net"
2023
"net/http"
2124
"os"
@@ -30,6 +33,7 @@ import (
3033
"go.mongodb.org/mongo-driver/v2/internal/handshake"
3134
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
3235
"go.mongodb.org/mongo-driver/v2/internal/integtest"
36+
"go.mongodb.org/mongo-driver/v2/internal/require"
3337
"go.mongodb.org/mongo-driver/v2/mongo"
3438
"go.mongodb.org/mongo-driver/v2/mongo/options"
3539
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
@@ -2918,7 +2922,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
29182922
}
29192923
})
29202924

2921-
mt.RunOpts("22. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) {
2925+
mt.RunOpts("23. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) {
29222926
err := mt.Client.Database("keyvault").Collection("datakeys").Drop(context.Background())
29232927
assert.Nil(mt, err, "error on Drop: %v", err)
29242928

@@ -2979,6 +2983,75 @@ func TestClientSideEncryptionProse(t *testing.T) {
29792983
assert.Greater(t, len(payload.Data), len(payloadDefaults.Data), "the returned payload size is expected to be greater than %d", len(payloadDefaults.Data))
29802984
})
29812985
})
2986+
2987+
mt.RunOpts("24. KMS Retry Tests", qeRunOpts22, func(mt *mtest.T) {
2988+
setFailPoint := func(failure string, count int) error {
2989+
url := fmt.Sprintf("https://localhost:9003/set_failpoint/%s", failure)
2990+
var payloadBuf bytes.Buffer
2991+
body := map[string]int{"count": count}
2992+
json.NewEncoder(&payloadBuf).Encode(body)
2993+
req, err := http.NewRequest(http.MethodPost, url, &payloadBuf)
2994+
if err != nil {
2995+
return err
2996+
}
2997+
2998+
entries, err := os.ReadDir("~/")
2999+
if err != nil {
3000+
log.Fatal(err)
3001+
}
3002+
for _, e := range entries {
3003+
fmt.Println(e.Name())
3004+
}
3005+
3006+
cert, err := ioutil.ReadFile("drivers-evergreen-tools/.evergreen/x509gen/ca.pem")
3007+
if err != nil {
3008+
return err
3009+
}
3010+
3011+
certPool := x509.NewCertPool()
3012+
certPool.AppendCertsFromPEM(cert)
3013+
3014+
client := &http.Client{
3015+
Transport: &http.Transport{
3016+
TLSClientConfig: &tls.Config{
3017+
RootCAs: certPool,
3018+
},
3019+
},
3020+
}
3021+
_, err = client.Do(req)
3022+
return err
3023+
}
3024+
3025+
keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI()))
3026+
require.NoError(mt, err, "error on Connect: %v", err)
3027+
3028+
ceo := options.ClientEncryption().
3029+
SetKeyVaultNamespace("keyvault.datakeys").
3030+
SetKmsProviders(fullKmsProvidersMap)
3031+
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3032+
require.NoError(mt, err, "error on NewClientEncryption: %v", err)
3033+
3034+
err = setFailPoint("http", 1)
3035+
require.NoError(mt, err, "mock server error: %v", err)
3036+
3037+
dkOpts := options.DataKey().SetMasterKey(
3038+
bson.D{
3039+
{"region", "foo"},
3040+
{"key", "bar"},
3041+
{"endpoint", "127.0.0.1:9003"},
3042+
},
3043+
)
3044+
var keyID bson.Binary
3045+
keyID, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3046+
require.NoError(mt, err, "error in CreateDataKey: %v", err)
3047+
3048+
testVal := bson.RawValue{Type: bson.TypeInt32, Value: bsoncore.AppendInt32(nil, 123)}
3049+
eo := options.Encrypt().
3050+
SetKeyID(keyID).
3051+
SetAlgorithm("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic")
3052+
_, err = clientEncryption.Encrypt(context.Background(), testVal, eo)
3053+
assert.NoError(mt, err, "error in Encrypt: %v", err)
3054+
})
29823055
}
29833056

29843057
func getWatcher(mt *mtest.T, streamType mongo.StreamType, cpt *cseProseTest) watcher {

0 commit comments

Comments
 (0)