Skip to content

Commit 06a8e69

Browse files
committed
update NeedKms logic
1 parent 12a1530 commit 06a8e69

File tree

8 files changed

+37
-2
lines changed

8 files changed

+37
-2
lines changed

Taskfile.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ tasks:
142142
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite
143143

144144
evg-test-retry-kms-requests:
145-
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite
145+
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout 300s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite
146146

147147
evg-test-load-balancers:
148148
# Load balancer should be tested with all unified tests as well as tests in the following

etc/install-libmongocrypt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This script installs libmongocrypt into an "install" directory.
44
set -eux
55

6-
LIBMONGOCRYPT_TAG="1.11.0"
6+
LIBMONGOCRYPT_TAG="1.12.0"
77

88
# Install libmongocrypt based on OS.
99
if [ "Windows_NT" = "${OS:-}" ]; then

mongo/client_encryption.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,14 @@ func (ce *ClientEncryption) CreateDataKey(
188188
}
189189

190190
// create data key document
191+
fmt.Println("CreateDataKey")
191192
dataKeyDoc, err := ce.crypt.CreateDataKey(ctx, kmsProvider, co)
192193
if err != nil {
193194
return bson.Binary{}, err
194195
}
195196

196197
// insert key into key vault
198+
fmt.Println("InsertOne")
197199
_, err = ce.keyVaultColl.InsertOne(ctx, dataKeyDoc)
198200
if err != nil {
199201
return bson.Binary{}, err

x/mongo/driver/crypt.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Co
260260
var err error
261261
for {
262262
state := cryptCtx.State()
263+
fmt.Println("state", state)
263264
switch state {
264265
case mongocrypt.NeedMongoCollInfo:
265266
err = c.collectionInfo(ctx, cryptCtx, db)
@@ -341,6 +342,7 @@ func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context)
341342
}
342343

343344
func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
345+
c.mongoCrypt.EnableRetry()
344346
for {
345347
kmsCtx := cryptCtx.NextKmsContext()
346348
if kmsCtx == nil {
@@ -376,8 +378,10 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
376378
if tlsCfg == nil {
377379
tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
378380
}
381+
fmt.Println("dial", addr)
379382
conn, err := tls.Dial("tcp", addr, tlsCfg)
380383
if err != nil {
384+
fmt.Println("dial error", err)
381385
return err
382386
}
383387
defer func() {
@@ -388,6 +392,7 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
388392
return err
389393
}
390394
if _, err = conn.Write(msg); err != nil {
395+
fmt.Println("conn write", err)
391396
return err
392397
}
393398

@@ -400,6 +405,11 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
400405
res := make([]byte, bytesNeeded)
401406
bytesRead, err := conn.Read(res)
402407
if err != nil && !errors.Is(err, io.EOF) {
408+
fail := kmsCtx.Fail()
409+
fmt.Println("conn read", err, fail)
410+
if fail {
411+
continue
412+
}
403413
return err
404414
}
405415

x/mongo/driver/mongocrypt/mongocrypt.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,8 @@ func (m *MongoCrypt) GetKmsProviders(ctx context.Context) (bsoncore.Document, er
522522
}
523523
return builder.Build(), nil
524524
}
525+
526+
// EnableRetry enables retry.
527+
func (m *MongoCrypt) EnableRetry() {
528+
_ = C.mongocrypt_setopt_retry_kms(m.wrapped, true)
529+
}

x/mongo/driver/mongocrypt/mongocrypt_kms_context.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package mongocrypt
1111

1212
// #include <mongocrypt.h>
1313
import "C"
14+
import "time"
1415

1516
// KmsContext represents a mongocrypt_kms_ctx_t handle.
1617
type KmsContext struct {
@@ -41,6 +42,8 @@ func (kc *KmsContext) KMSProvider() string {
4142

4243
// Message returns the message to send to the KMS.
4344
func (kc *KmsContext) Message() ([]byte, error) {
45+
time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond)
46+
4447
msgBinary := newBinary()
4548
defer msgBinary.close()
4649

@@ -74,3 +77,8 @@ func (kc *KmsContext) createErrorFromStatus() error {
7477
C.mongocrypt_kms_ctx_status(kc.wrapped, status)
7578
return errorFromStatus(status)
7679
}
80+
81+
// Fail returns a boolean indicating whether the failed request may be retried.
82+
func (kc *KmsContext) Fail() bool {
83+
return bool(C.mongocrypt_kms_ctx_fail(kc.wrapped))
84+
}

x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@ func (kc *KmsContext) BytesNeeded() int32 {
3737
func (kc *KmsContext) FeedResponse([]byte) error {
3838
panic(cseNotSupportedMsg)
3939
}
40+
41+
// Fail returns a boolean indicating whether the failed request may be retried.
42+
func (kc *KmsContext) Fail() bool {
43+
panic(cseNotSupportedMsg)
44+
}

x/mongo/driver/mongocrypt/mongocrypt_not_enabled.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,8 @@ func (m *MongoCrypt) Close() {
9595
func (m *MongoCrypt) GetKmsProviders(context.Context) (bsoncore.Document, error) {
9696
panic(cseNotSupportedMsg)
9797
}
98+
99+
// EnableRetry enables retry.
100+
func (m *MongoCrypt) EnableRetry() {
101+
panic(cseNotSupportedMsg)
102+
}

0 commit comments

Comments
 (0)