Skip to content

Commit a3ad820

Browse files
GODRIVER-3168 Retry KMS requests on transient errors. (#1887)
Co-authored-by: Preston Vasquez <[email protected]>
1 parent f99da4d commit a3ad820

File tree

8 files changed

+203
-12
lines changed

8 files changed

+203
-12
lines changed

.evergreen/config.yml

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,25 @@ functions:
561561
KMS_MOCK_SERVERS_RUNNING: "true"
562562
args: [*task-runner, evg-test-kmip]
563563

564+
run-retry-kms-requests:
565+
- command: subprocess.exec
566+
type: test
567+
params:
568+
binary: "bash"
569+
env:
570+
GO_BUILD_TAGS: cse
571+
include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY,
572+
MONGO_GO_DRIVER_COMPRESSOR]
573+
args: [*task-runner, setup-test]
574+
- command: subprocess.exec
575+
type: test
576+
params:
577+
binary: "bash"
578+
env:
579+
KMS_FAILPOINT_CA_FILE: "${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem"
580+
KMS_FAILPOINT_SERVER_RUNNING: "true"
581+
args: [*task-runner, evg-test-retry-kms-requests]
582+
564583
run-fuzz-tests:
565584
- command: subprocess.exec
566585
type: test
@@ -1468,7 +1487,7 @@ tasks:
14681487
SSL: "nossl"
14691488

14701489
- name: "test-kms-tls-invalid-cert"
1471-
tags: ["kms-tls"]
1490+
tags: ["kms-test"]
14721491
commands:
14731492
- func: bootstrap-mongo-orchestration
14741493
vars:
@@ -1484,7 +1503,7 @@ tasks:
14841503
SSL: "nossl"
14851504

14861505
- name: "test-kms-tls-invalid-hostname"
1487-
tags: ["kms-tls"]
1506+
tags: ["kms-test"]
14881507
commands:
14891508
- func: bootstrap-mongo-orchestration
14901509
vars:
@@ -1514,6 +1533,17 @@ tasks:
15141533
AUTH: "noauth"
15151534
SSL: "nossl"
15161535

1536+
- name: "test-retry-kms-requests"
1537+
tags: ["kms-test"]
1538+
commands:
1539+
- func: bootstrap-mongo-orchestration
1540+
vars:
1541+
TOPOLOGY: "server"
1542+
AUTH: "noauth"
1543+
SSL: "nossl"
1544+
- func: start-cse-servers
1545+
- func: run-retry-kms-requests
1546+
15171547
- name: "test-serverless"
15181548
tags: ["serverless"]
15191549
commands:
@@ -2201,11 +2231,11 @@ buildvariants:
22012231
tasks:
22022232
- name: ".versioned-api"
22032233

2204-
- matrix_name: "kms-tls-test"
2234+
- matrix_name: "kms-test"
22052235
matrix_spec: { version: ["7.0"], os-ssl-40: ["rhel87-64"] }
2206-
display_name: "KMS TLS ${os-ssl-40}"
2236+
display_name: "KMS TEST ${os-ssl-40}"
22072237
tasks:
2208-
- name: ".kms-tls"
2238+
- name: ".kms-test"
22092239

22102240
- matrix_name: "load-balancer-test"
22112241
tags: ["pullrequest"]

Taskfile.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ tasks:
143143
evg-test-kms:
144144
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite
145145

146+
evg-test-retry-kms-requests:
147+
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite
148+
146149
evg-test-load-balancers:
147150
# Load balancer should be tested with all unified tests as well as tests in the following
148151
# components: retryable reads, retryable writes, change streams, initial DNS seedlist discovery.

etc/install-libmongocrypt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This script installs libmongocrypt into an "install" directory.
44
set -eux
55

6-
LIBMONGOCRYPT_TAG="1.11.0"
6+
LIBMONGOCRYPT_TAG="1.12.0"
77

88
# Install libmongocrypt based on OS.
99
if [ "Windows_NT" = "${OS:-}" ]; then

internal/integration/client_side_encryption_prose_test.go

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"context"
1515
"crypto/tls"
1616
"encoding/base64"
17+
"encoding/json"
1718
"fmt"
1819
"io/ioutil"
1920
"net"
@@ -30,6 +31,7 @@ import (
3031
"go.mongodb.org/mongo-driver/v2/internal/handshake"
3132
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
3233
"go.mongodb.org/mongo-driver/v2/internal/integtest"
34+
"go.mongodb.org/mongo-driver/v2/internal/require"
3335
"go.mongodb.org/mongo-driver/v2/mongo"
3436
"go.mongodb.org/mongo-driver/v2/mongo/options"
3537
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
@@ -2925,7 +2927,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
29252927
}
29262928
})
29272929

2928-
mt.RunOpts("22. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) {
2930+
mt.RunOpts("23. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) {
29292931
err := mt.Client.Database("keyvault").Collection("datakeys").Drop(context.Background())
29302932
assert.Nil(mt, err, "error on Drop: %v", err)
29312933

@@ -2986,6 +2988,147 @@ func TestClientSideEncryptionProse(t *testing.T) {
29862988
assert.Greater(t, len(payload.Data), len(payloadDefaults.Data), "the returned payload size is expected to be greater than %d", len(payloadDefaults.Data))
29872989
})
29882990
})
2991+
2992+
mt.RunOpts("24. kms retry tests", noClientOpts, func(mt *mtest.T) {
2993+
kmsTlsTestcase := os.Getenv("KMS_FAILPOINT_SERVER_RUNNING")
2994+
if kmsTlsTestcase == "" {
2995+
mt.Skipf("Skipping test as KMS_FAILPOINT_SERVER_RUNNING is not set")
2996+
}
2997+
2998+
mt.Parallel()
2999+
3000+
tlsCAFile := os.Getenv("KMS_FAILPOINT_CA_FILE")
3001+
require.NotEqual(mt, tlsCAFile, "", "failed to load CA file")
3002+
3003+
clientAndCATlsMap := map[string]interface{}{
3004+
"tlsCAFile": tlsCAFile,
3005+
}
3006+
tlsCfg, err := options.BuildTLSConfig(clientAndCATlsMap)
3007+
require.NoError(mt, err, "BuildTLSConfig error: %v", err)
3008+
3009+
setFailPoint := func(failure string, count int) error {
3010+
url := fmt.Sprintf("https://localhost:9003/set_failpoint/%s", failure)
3011+
var payloadBuf bytes.Buffer
3012+
body := map[string]int{"count": count}
3013+
json.NewEncoder(&payloadBuf).Encode(body)
3014+
req, err := http.NewRequest(http.MethodPost, url, &payloadBuf)
3015+
if err != nil {
3016+
return err
3017+
}
3018+
3019+
client := &http.Client{
3020+
Transport: &http.Transport{TLSClientConfig: tlsCfg},
3021+
}
3022+
res, err := client.Do(req)
3023+
if err != nil {
3024+
return err
3025+
}
3026+
return res.Body.Close()
3027+
}
3028+
3029+
kmsProviders := map[string]map[string]interface{}{
3030+
"aws": {
3031+
"accessKeyId": awsAccessKeyID,
3032+
"secretAccessKey": awsSecretAccessKey,
3033+
},
3034+
"azure": {
3035+
"tenantId": azureTenantID,
3036+
"clientId": azureClientID,
3037+
"clientSecret": azureClientSecret,
3038+
"identityPlatformEndpoint": "127.0.0.1:9003",
3039+
},
3040+
"gcp": {
3041+
"email": gcpEmail,
3042+
"privateKey": gcpPrivateKey,
3043+
"endpoint": "127.0.0.1:9003",
3044+
},
3045+
}
3046+
3047+
dataKeys := []struct {
3048+
provider string
3049+
masterKey interface{}
3050+
}{
3051+
{"aws", bson.D{
3052+
{"region", "foo"},
3053+
{"key", "bar"},
3054+
{"endpoint", "127.0.0.1:9003"},
3055+
}},
3056+
{"azure", bson.D{
3057+
{"keyVaultEndpoint", "127.0.0.1:9003"},
3058+
{"keyName", "foo"},
3059+
}},
3060+
{"gcp", bson.D{
3061+
{"projectId", "foo"},
3062+
{"location", "bar"},
3063+
{"keyRing", "baz"},
3064+
{"keyName", "qux"},
3065+
{"endpoint", "127.0.0.1:9003"},
3066+
}},
3067+
}
3068+
3069+
testCases := []struct {
3070+
name string
3071+
failure string
3072+
}{
3073+
{"Case 1: createDataKey and encrypt with TCP retry", "network"},
3074+
{"Case 2: createDataKey and encrypt with HTTP retry", "http"},
3075+
}
3076+
3077+
for _, tc := range testCases {
3078+
for _, dataKey := range dataKeys {
3079+
mt.Run(fmt.Sprintf("%s_%s", tc.name, dataKey.provider), func(mt *mtest.T) {
3080+
keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI()))
3081+
require.NoError(mt, err, "error on Connect: %v", err)
3082+
3083+
ceo := options.ClientEncryption().
3084+
SetKeyVaultNamespace(kvNamespace).
3085+
SetKmsProviders(kmsProviders).
3086+
SetTLSConfig(map[string]*tls.Config{dataKey.provider: tlsCfg})
3087+
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3088+
require.NoError(mt, err, "error on NewClientEncryption: %v", err)
3089+
3090+
err = setFailPoint(tc.failure, 1)
3091+
require.NoError(mt, err, "mock server error: %v", err)
3092+
3093+
dkOpts := options.DataKey().SetMasterKey(dataKey.masterKey)
3094+
var keyID bson.Binary
3095+
keyID, err = clientEncryption.CreateDataKey(context.Background(), dataKey.provider, dkOpts)
3096+
require.NoError(mt, err, "error in CreateDataKey: %v", err)
3097+
3098+
err = setFailPoint(tc.failure, 1)
3099+
require.NoError(mt, err, "mock server error: %v", err)
3100+
3101+
testVal := bson.RawValue{Type: bson.TypeInt32, Value: bsoncore.AppendInt32(nil, 123)}
3102+
eo := options.Encrypt().
3103+
SetKeyID(keyID).
3104+
SetAlgorithm("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic")
3105+
_, err = clientEncryption.Encrypt(context.Background(), testVal, eo)
3106+
require.NoError(mt, err, "error in Encrypt: %v", err)
3107+
})
3108+
}
3109+
}
3110+
3111+
for _, dataKey := range dataKeys {
3112+
mt.Run(fmt.Sprintf("Case 3: createDataKey fails after too many retries_%s", dataKey.provider), func(mt *mtest.T) {
3113+
keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI()))
3114+
require.NoError(mt, err, "error on Connect: %v", err)
3115+
3116+
ceo := options.ClientEncryption().
3117+
SetKeyVaultNamespace(kvNamespace).
3118+
SetKmsProviders(kmsProviders).
3119+
SetTLSConfig(map[string]*tls.Config{dataKey.provider: tlsCfg})
3120+
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3121+
require.NoError(mt, err, "error on NewClientEncryption: %v", err)
3122+
3123+
err = setFailPoint("network", 4)
3124+
require.NoError(mt, err, "mock server error: %v", err)
3125+
3126+
dkOpts := options.DataKey().SetMasterKey(dataKey.masterKey)
3127+
_, err = clientEncryption.CreateDataKey(context.Background(), dataKey.provider, dkOpts)
3128+
require.ErrorContains(mt, err, "KMS request failed after 3 retries due to a network error")
3129+
})
3130+
}
3131+
})
29893132
}
29903133

29913134
func getWatcher(mt *mtest.T, streamType mongo.StreamType, cpt *cseProseTest) watcher {

x/mongo/driver/crypt.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ package driver
99
import (
1010
"context"
1111
"crypto/tls"
12-
"errors"
1312
"fmt"
14-
"io"
1513
"strings"
1614
"time"
1715

@@ -399,8 +397,8 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
399397

400398
res := make([]byte, bytesNeeded)
401399
bytesRead, err := conn.Read(res)
402-
if err != nil && !errors.Is(err, io.EOF) {
403-
return err
400+
if err != nil {
401+
return kmsCtx.RequestError()
404402
}
405403

406404
if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {

x/mongo/driver/mongocrypt/mongocrypt.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
5353
if wrapped == nil {
5454
return nil, errors.New("could not create new mongocrypt object")
5555
}
56+
C.mongocrypt_setopt_retry_kms(wrapped, true)
5657
httpClient := opts.HTTPClient
5758
if httpClient == nil {
5859
httpClient = httputil.DefaultHTTPClient
@@ -85,7 +86,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
8586
}
8687

8788
if opts.BypassQueryAnalysis {
88-
C.mongocrypt_setopt_bypass_query_analysis(wrapped)
89+
C.mongocrypt_setopt_bypass_query_analysis(crypt.wrapped)
8990
}
9091

9192
// If loading the crypt_shared library isn't disabled, set the default library search path "$SYSTEM"

x/mongo/driver/mongocrypt/mongocrypt_kms_context.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package mongocrypt
1111

1212
// #include <mongocrypt.h>
1313
import "C"
14+
import "time"
1415

1516
// KmsContext represents a mongocrypt_kms_ctx_t handle.
1617
type KmsContext struct {
@@ -41,6 +42,8 @@ func (kc *KmsContext) KMSProvider() string {
4142

4243
// Message returns the message to send to the KMS.
4344
func (kc *KmsContext) Message() ([]byte, error) {
45+
time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond)
46+
4447
msgBinary := newBinary()
4548
defer msgBinary.close()
4649

@@ -74,3 +77,11 @@ func (kc *KmsContext) createErrorFromStatus() error {
7477
C.mongocrypt_kms_ctx_status(kc.wrapped, status)
7578
return errorFromStatus(status)
7679
}
80+
81+
// RequestError returns the source of the network error for KMS requests.
82+
func (kc *KmsContext) RequestError() error {
83+
if bool(C.mongocrypt_kms_ctx_fail(kc.wrapped)) {
84+
return nil
85+
}
86+
return kc.createErrorFromStatus()
87+
}

x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@ func (kc *KmsContext) BytesNeeded() int32 {
3737
func (kc *KmsContext) FeedResponse([]byte) error {
3838
panic(cseNotSupportedMsg)
3939
}
40+
41+
// RequestError returns the source of the network error for KMS requests.
42+
func (kc *KmsContext) RequestError() error {
43+
panic(cseNotSupportedMsg)
44+
}

0 commit comments

Comments
 (0)