Skip to content

Commit 57a5393

Browse files
authored
GODRIVER-2149 Add private API to allow overriding CSFLE functionality (#789)
1 parent 2806752 commit 57a5393

29 files changed

+250
-67
lines changed

mongo/change_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ type changeStreamConfig struct {
9494
streamType StreamType
9595
collectionName string
9696
databaseName string
97-
crypt *driver.Crypt
97+
crypt driver.Crypt
9898
}
9999

100100
func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{},

mongo/client.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ type Client struct {
6868
keyVaultClientFLE *Client
6969
keyVaultCollFLE *Collection
7070
mongocryptdFLE *mcryptClient
71-
cryptFLE *driver.Crypt
71+
cryptFLE driver.Crypt
7272
metadataClientFLE *Client
7373
internalClientFLE *Client
7474
}
@@ -634,6 +634,8 @@ func (c *Client) configure(opts *options.ClientOptions) error {
634634
if err := c.configureAutoEncryption(opts); err != nil {
635635
return err
636636
}
637+
} else {
638+
c.cryptFLE = opts.Crypt
637639
}
638640

639641
// OCSP cache

mongo/client_encryption.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222

2323
// ClientEncryption is used to create data keys and explicitly encrypt and decrypt BSON values.
2424
type ClientEncryption struct {
25-
crypt *driver.Crypt
25+
crypt driver.Crypt
2626
keyVaultClient *Client
2727
keyVaultColl *Collection
2828
}

mongo/integration/client_side_encryption_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
2323
"go.mongodb.org/mongo-driver/mongo/options"
2424
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
25+
"go.mongodb.org/mongo-driver/x/mongo/driver"
26+
mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
2527
)
2628

2729
// createDataKeyAndEncrypt creates a data key with the alternate name @keyName.
@@ -235,3 +237,146 @@ func TestClientSideEncryptionWithExplicitSessions(t *testing.T) {
235237
assert.NotEqual(mt, lsid, session.ID(), "expected different lsid, but got %v", lsid)
236238
})
237239
}
240+
241+
// customCrypt is a test implementation of the driver.Crypt interface. It keeps track of the number of times its
242+
// methods have been called.
243+
type customCrypt struct {
244+
numEncryptCalls int
245+
numDecryptCalls int
246+
numCreateDataKeyCalls int
247+
numEncryptExplicitCalls int
248+
numDecryptExplicitCalls int
249+
numCloseCalls int
250+
numBypassAutoEncryptionCalls int
251+
}
252+
253+
var (
254+
_ driver.Crypt = (*customCrypt)(nil)
255+
mySSN = "123456789"
256+
)
257+
258+
// Encrypt encrypts the given command.
259+
func (c *customCrypt) Encrypt(_ context.Context, _ string, cmd bsoncore.Document) (bsoncore.Document, error) {
260+
c.numEncryptCalls++
261+
elems, err := cmd.Elements()
262+
if err != nil {
263+
return nil, err
264+
}
265+
266+
encryptedCmd := bsoncore.NewDocumentBuilder()
267+
for _, elem := range elems {
268+
// "encrypt" ssn element as "hidden"
269+
if elem.Key() == "ssn" {
270+
encryptedCmd = encryptedCmd.AppendString("ssn", "hidden")
271+
} else {
272+
encryptedCmd = encryptedCmd.AppendValue(elem.Key(), elem.Value())
273+
}
274+
}
275+
return encryptedCmd.Build(), nil
276+
}
277+
278+
// Decrypt decrypts the given command response.
279+
func (c *customCrypt) Decrypt(_ context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
280+
c.numDecryptCalls++
281+
elems, err := cmdResponse.Elements()
282+
if err != nil {
283+
return nil, err
284+
}
285+
286+
decryptedCmdResponse := bsoncore.NewDocumentBuilder()
287+
for _, elem := range elems {
288+
// "decrypt" ssn element as mySSN
289+
if elem.Key() == "ssn" {
290+
decryptedCmdResponse = decryptedCmdResponse.AppendString("ssn", mySSN)
291+
} else {
292+
decryptedCmdResponse = decryptedCmdResponse.AppendValue(elem.Key(), elem.Value())
293+
}
294+
}
295+
return decryptedCmdResponse.Build(), nil
296+
}
297+
298+
// CreateDataKey implements the driver.Crypt interface.
299+
func (c *customCrypt) CreateDataKey(_ context.Context, _ string, _ *mcopts.DataKeyOptions) (bsoncore.Document, error) {
300+
c.numCreateDataKeyCalls++
301+
return nil, nil
302+
}
303+
304+
// EncryptExplicit implements the driver.Crypt interface.
305+
func (c *customCrypt) EncryptExplicit(_ context.Context, _ bsoncore.Value, _ *mcopts.ExplicitEncryptionOptions) (byte, []byte, error) {
306+
c.numEncryptExplicitCalls++
307+
return 0, nil, nil
308+
}
309+
310+
// DecryptExplicit implements the driver.Crypt interface.
311+
func (c *customCrypt) DecryptExplicit(_ context.Context, _ byte, _ []byte) (bsoncore.Value, error) {
312+
c.numDecryptExplicitCalls++
313+
return bsoncore.Value{}, nil
314+
}
315+
316+
// Close implements the driver.Crypt interface.
317+
func (c *customCrypt) Close() {
318+
c.numCloseCalls++
319+
}
320+
321+
// BypassAutoEncryption implements the driver.Crypt interface.
322+
func (c *customCrypt) BypassAutoEncryption() bool {
323+
c.numBypassAutoEncryptionCalls++
324+
return false
325+
}
326+
327+
func TestClientSideEncryptionCustomCrypt(t *testing.T) {
328+
mt := mtest.New(t, mtest.NewOptions().MinServerVersion("4.2").Enterprise(true).CreateClient(false))
329+
defer mt.Close()
330+
331+
kmsProvidersMap := map[string]map[string]interface{}{
332+
"local": {"key": localMasterKey},
333+
}
334+
335+
mt.Run("auto encryption and decryption", func(mt *mtest.T) {
336+
aeOpts := options.AutoEncryption().
337+
SetKmsProviders(kmsProvidersMap).
338+
SetKeyVaultNamespace("keyvault.datakeys")
339+
clientOpts := options.Client().
340+
ApplyURI(mtest.ClusterURI()).
341+
SetAutoEncryptionOptions(aeOpts)
342+
cc := &customCrypt{}
343+
clientOpts.Crypt = cc
344+
testutil.AddTestServerAPIVersion(clientOpts)
345+
346+
client, err := mongo.Connect(mtest.Background, clientOpts)
347+
defer client.Disconnect(mtest.Background)
348+
assert.Nil(mt, err, "Connect error: %v", err)
349+
350+
coll := client.Database("db").Collection("coll")
351+
defer func() { _ = coll.Drop(mtest.Background) }()
352+
353+
doc := bson.D{{"foo", "bar"}, {"ssn", mySSN}}
354+
_, err = coll.InsertOne(mtest.Background, doc)
355+
assert.Nil(mt, err, "InsertOne error: %v", err)
356+
357+
res := coll.FindOne(mtest.Background, bson.D{{"foo", "bar"}})
358+
assert.Nil(mt, res.Err(), "FindOne error: %v", err)
359+
360+
rawRes, err := res.DecodeBytes()
361+
assert.Nil(mt, err, "DecodeBytes error: %v", err)
362+
ssn, ok := rawRes.Lookup("ssn").StringValueOK()
363+
assert.True(mt, ok, "expected 'ssn' value to be type string, got %T", ssn)
364+
assert.Equal(mt, ssn, mySSN, "expected 'ssn' value %q, got %q", mySSN, ssn)
365+
366+
// Assert customCrypt methods are called the correct number of times.
367+
assert.Equal(mt, cc.numEncryptCalls, 1,
368+
"expected 1 call to Encrypt, got %v", cc.numEncryptCalls)
369+
assert.Equal(mt, cc.numDecryptCalls, 1,
370+
"expected 1 call to Decrypt, got %v", cc.numDecryptCalls)
371+
assert.Equal(mt, cc.numCreateDataKeyCalls, 0,
372+
"expected 0 calls to CreateDataKey, got %v", cc.numCreateDataKeyCalls)
373+
assert.Equal(mt, cc.numEncryptExplicitCalls, 0,
374+
"expected 0 calls to EncryptExplicit, got %v", cc.numEncryptExplicitCalls)
375+
assert.Equal(mt, cc.numDecryptExplicitCalls, 0,
376+
"expected 0 calls to DecryptExplicit, got %v", cc.numDecryptExplicitCalls)
377+
assert.Equal(mt, cc.numCloseCalls, 0,
378+
"expected 0 calls to Close, got %v", cc.numCloseCalls)
379+
assert.Equal(mt, cc.numBypassAutoEncryptionCalls, 2,
380+
"expected 2 calls to BypassAutoEncryption, got %v", cc.numBypassAutoEncryptionCalls)
381+
})
382+
}

mongo/options/clientoptions.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ type ClientOptions struct {
138138
// release.
139139
AuthenticateToAnything *bool
140140

141+
// Crypt specifies a custom driver.Crypt to be used to encrypt and decrypt documents. The default is no
142+
// encryption.
143+
//
144+
// Deprecated: This option is for internal use only and should not be set (see GODRIVER-2149). It may be
145+
// changed or removed in any release.
146+
Crypt driver.Crypt
147+
141148
// Deployment specifies a custom deployment to use for the new Client.
142149
//
143150
// Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any
@@ -828,6 +835,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions {
828835
if opt.ConnectTimeout != nil {
829836
c.ConnectTimeout = opt.ConnectTimeout
830837
}
838+
if opt.Crypt != nil {
839+
c.Crypt = opt.Crypt
840+
}
831841
if opt.HeartbeatInterval != nil {
832842
c.HeartbeatInterval = opt.HeartbeatInterval
833843
}

x/mongo/driver/batch_cursor.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type BatchCursor struct {
3232
firstBatch bool
3333
cmdMonitor *event.CommandMonitor
3434
postBatchResumeToken bsoncore.Document
35-
crypt *Crypt
35+
crypt Crypt
3636
serverAPI *ServerAPIOptions
3737

3838
// legacy server (< 3.2) fields
@@ -130,7 +130,7 @@ type CursorOptions struct {
130130
MaxTimeMS int64
131131
Limit int32
132132
CommandMonitor *event.CommandMonitor
133-
Crypt *Crypt
133+
Crypt Crypt
134134
ServerAPI *ServerAPIOptions
135135
}
136136

x/mongo/driver/crypt.go

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,46 @@ type CryptOptions struct {
4343
BypassAutoEncryption bool
4444
}
4545

46-
// Crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
46+
// Crypt is an interface implemented by types that can encrypt and decrypt instances of
47+
// bsoncore.Document.
48+
//
49+
// Users should rely on the driver's crypt type (used by default) for encryption and decryption
50+
// unless they are perfectly confident in another implementation of Crypt.
51+
type Crypt interface {
52+
// Encrypt encrypts the given command.
53+
Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
54+
// Decrypt decrypts the given command response.
55+
Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
56+
// CreateDataKey creates a data key using the given KMS provider and options.
57+
CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
58+
// EncryptExplicit encrypts the given value with the given options.
59+
EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
60+
// DecryptExplicit decrypts the given encrypted value.
61+
DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
62+
// Close cleans up any resources associated with the Crypt instance.
63+
Close()
64+
// BypassAutoEncryption returns true if auto-encryption should be bypassed.
65+
BypassAutoEncryption() bool
66+
}
67+
68+
// crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
4769
// and decryption.
48-
type Crypt struct {
70+
type crypt struct {
4971
mongoCrypt *mongocrypt.MongoCrypt
5072
collInfoFn CollectionInfoFn
5173
keyFn KeyRetrieverFn
5274
markFn MarkCommandFn
5375

54-
BypassAutoEncryption bool
76+
bypassAutoEncryption bool
5577
}
5678

5779
// NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
58-
func NewCrypt(opts *CryptOptions) (*Crypt, error) {
59-
c := &Crypt{
80+
func NewCrypt(opts *CryptOptions) (Crypt, error) {
81+
c := &crypt{
6082
collInfoFn: opts.CollInfoFn,
6183
keyFn: opts.KeyFn,
6284
markFn: opts.MarkFn,
63-
BypassAutoEncryption: opts.BypassAutoEncryption,
85+
bypassAutoEncryption: opts.BypassAutoEncryption,
6486
}
6587

6688
mongocryptOpts := options.MongoCrypt().SetKmsProviders(opts.KmsProviders).SetLocalSchemaMap(opts.SchemaMap)
@@ -74,8 +96,8 @@ func NewCrypt(opts *CryptOptions) (*Crypt, error) {
7496
}
7597

7698
// Encrypt encrypts the given command.
77-
func (c *Crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
78-
if c.BypassAutoEncryption {
99+
func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
100+
if c.bypassAutoEncryption {
79101
return cmd, nil
80102
}
81103

@@ -89,7 +111,7 @@ func (c *Crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (
89111
}
90112

91113
// Decrypt decrypts the given command response.
92-
func (c *Crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
114+
func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
93115
cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
94116
if err != nil {
95117
return nil, err
@@ -100,7 +122,7 @@ func (c *Crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bso
100122
}
101123

102124
// CreateDataKey creates a data key using the given KMS provider and options.
103-
func (c *Crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
125+
func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
104126
cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
105127
if err != nil {
106128
return nil, err
@@ -111,7 +133,7 @@ func (c *Crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *opt
111133
}
112134

113135
// EncryptExplicit encrypts the given value with the given options.
114-
func (c *Crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
136+
func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
115137
idx, doc := bsoncore.AppendDocumentStart(nil)
116138
doc = bsoncore.AppendValueElement(doc, "v", val)
117139
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
@@ -132,7 +154,7 @@ func (c *Crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *o
132154
}
133155

134156
// DecryptExplicit decrypts the given encrypted value.
135-
func (c *Crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
157+
func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
136158
idx, doc := bsoncore.AppendDocumentStart(nil)
137159
doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
138160
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
@@ -152,11 +174,15 @@ func (c *Crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte)
152174
}
153175

154176
// Close cleans up any resources associated with the Crypt instance.
155-
func (c *Crypt) Close() {
177+
func (c *crypt) Close() {
156178
c.mongoCrypt.Close()
157179
}
158180

159-
func (c *Crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
181+
func (c *crypt) BypassAutoEncryption() bool {
182+
return c.bypassAutoEncryption
183+
}
184+
185+
func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
160186
var err error
161187
for {
162188
state := cryptCtx.State()
@@ -180,7 +206,7 @@ func (c *Crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Co
180206
}
181207
}
182208

183-
func (c *Crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
209+
func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
184210
op, err := cryptCtx.NextOperation()
185211
if err != nil {
186212
return err
@@ -199,7 +225,7 @@ func (c *Crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context
199225
return cryptCtx.CompleteOperation()
200226
}
201227

202-
func (c *Crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
228+
func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
203229
op, err := cryptCtx.NextOperation()
204230
if err != nil {
205231
return err
@@ -216,7 +242,7 @@ func (c *Crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, d
216242
return cryptCtx.CompleteOperation()
217243
}
218244

219-
func (c *Crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
245+
func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
220246
op, err := cryptCtx.NextOperation()
221247
if err != nil {
222248
return err
@@ -236,7 +262,7 @@ func (c *Crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context)
236262
return cryptCtx.CompleteOperation()
237263
}
238264

239-
func (c *Crypt) decryptKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
265+
func (c *crypt) decryptKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
240266
for {
241267
kmsCtx := cryptCtx.NextKmsContext()
242268
if kmsCtx == nil {
@@ -251,7 +277,7 @@ func (c *Crypt) decryptKeys(ctx context.Context, cryptCtx *mongocrypt.Context) e
251277
return cryptCtx.FinishKmsContexts()
252278
}
253279

254-
func (c *Crypt) decryptKey(ctx context.Context, kmsCtx *mongocrypt.KmsContext) error {
280+
func (c *crypt) decryptKey(ctx context.Context, kmsCtx *mongocrypt.KmsContext) error {
255281
host, err := kmsCtx.HostName()
256282
if err != nil {
257283
return err

0 commit comments

Comments
 (0)