Skip to content

Commit 00e69f5

Browse files
committed
update NeedKms logic
1 parent 12a1530 commit 00e69f5

File tree

6 files changed

+34
-15
lines changed

6 files changed

+34
-15
lines changed

etc/install-libmongocrypt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This script installs libmongocrypt into an "install" directory.
44
set -eux
55

6-
LIBMONGOCRYPT_TAG="1.11.0"
6+
LIBMONGOCRYPT_TAG="1.12.0"
77

88
# Install libmongocrypt based on OS.
99
if [ "Windows_NT" = "${OS:-}" ]; then

internal/integration/client_side_encryption_prose_test.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2988,6 +2988,10 @@ func TestClientSideEncryptionProse(t *testing.T) {
29882988
mt.Skipf("Skipping test as KMS_FAILPOINT_SERVERS_RUNNING is not set")
29892989
}
29902990

2991+
tlsCfg := &tls.Config{
2992+
InsecureSkipVerify: true,
2993+
}
2994+
29912995
setFailPoint := func(failure string, count int) error {
29922996
url := fmt.Sprintf("https://localhost:9003/set_failpoint/%s", failure)
29932997
var payloadBuf bytes.Buffer
@@ -2999,26 +3003,26 @@ func TestClientSideEncryptionProse(t *testing.T) {
29993003
}
30003004

30013005
client := &http.Client{
3002-
Transport: &http.Transport{
3003-
TLSClientConfig: &tls.Config{
3004-
InsecureSkipVerify: true,
3005-
},
3006-
},
3006+
Transport: &http.Transport{TLSClientConfig: tlsCfg},
30073007
}
3008-
_, err = client.Do(req)
3009-
return err
3008+
res, err := client.Do(req)
3009+
if err != nil {
3010+
return err
3011+
}
3012+
return res.Body.Close()
30103013
}
30113014

30123015
keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI()))
30133016
require.NoError(mt, err, "error on Connect: %v", err)
30143017

30153018
ceo := options.ClientEncryption().
30163019
SetKeyVaultNamespace("keyvault.datakeys").
3017-
SetKmsProviders(fullKmsProvidersMap)
3020+
SetKmsProviders(fullKmsProvidersMap).
3021+
SetTLSConfig(map[string]*tls.Config{"aws": tlsCfg})
30183022
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
30193023
require.NoError(mt, err, "error on NewClientEncryption: %v", err)
30203024

3021-
err = setFailPoint("http", 1)
3025+
err = setFailPoint("network", 1)
30223026
require.NoError(mt, err, "mock server error: %v", err)
30233027

30243028
dkOpts := options.DataKey().SetMasterKey(
@@ -3032,7 +3036,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
30323036
keyID, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
30333037
require.NoError(mt, err, "error in CreateDataKey: %v", err)
30343038

3035-
err = setFailPoint("http", 1)
3039+
err = setFailPoint("network", 1)
30363040
require.NoError(mt, err, "mock server error: %v", err)
30373041

30383042
testVal := bson.RawValue{Type: bson.TypeInt32, Value: bsoncore.AppendInt32(nil, 123)}

x/mongo/driver/crypt.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ package driver
99
import (
1010
"context"
1111
"crypto/tls"
12-
"errors"
1312
"fmt"
14-
"io"
1513
"strings"
1614
"time"
1715

@@ -399,7 +397,10 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
399397

400398
res := make([]byte, bytesNeeded)
401399
bytesRead, err := conn.Read(res)
402-
if err != nil && !errors.Is(err, io.EOF) {
400+
if err != nil {
401+
if kmsCtx.Fail() {
402+
err = nil
403+
}
403404
return err
404405
}
405406

x/mongo/driver/mongocrypt/mongocrypt.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
5353
if wrapped == nil {
5454
return nil, errors.New("could not create new mongocrypt object")
5555
}
56+
C.mongocrypt_setopt_retry_kms(wrapped, true)
5657
httpClient := opts.HTTPClient
5758
if httpClient == nil {
5859
httpClient = httputil.DefaultHTTPClient
@@ -85,7 +86,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
8586
}
8687

8788
if opts.BypassQueryAnalysis {
88-
C.mongocrypt_setopt_bypass_query_analysis(wrapped)
89+
C.mongocrypt_setopt_bypass_query_analysis(crypt.wrapped)
8990
}
9091

9192
// If loading the crypt_shared library isn't disabled, set the default library search path "$SYSTEM"

x/mongo/driver/mongocrypt/mongocrypt_kms_context.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package mongocrypt
1111

1212
// #include <mongocrypt.h>
1313
import "C"
14+
import "time"
1415

1516
// KmsContext represents a mongocrypt_kms_ctx_t handle.
1617
type KmsContext struct {
@@ -41,6 +42,8 @@ func (kc *KmsContext) KMSProvider() string {
4142

4243
// Message returns the message to send to the KMS.
4344
func (kc *KmsContext) Message() ([]byte, error) {
45+
time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond)
46+
4447
msgBinary := newBinary()
4548
defer msgBinary.close()
4649

@@ -74,3 +77,8 @@ func (kc *KmsContext) createErrorFromStatus() error {
7477
C.mongocrypt_kms_ctx_status(kc.wrapped, status)
7578
return errorFromStatus(status)
7679
}
80+
81+
// Fail returns a boolean indicating whether the failed request may be retried.
82+
func (kc *KmsContext) Fail() bool {
83+
return bool(C.mongocrypt_kms_ctx_fail(kc.wrapped))
84+
}

x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@ func (kc *KmsContext) BytesNeeded() int32 {
3737
func (kc *KmsContext) FeedResponse([]byte) error {
3838
panic(cseNotSupportedMsg)
3939
}
40+
41+
// Fail returns a boolean indicating whether the failed request may be retried.
42+
func (kc *KmsContext) Fail() bool {
43+
panic(cseNotSupportedMsg)
44+
}

0 commit comments

Comments
 (0)