Skip to content

Commit a4a5bc3

Browse files
authored
GODRIVER-2529 Return error if a ClientEncryption is used after Close (#1785)
1 parent 3577ed5 commit a4a5bc3

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

mongo/client_encryption.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type ClientEncryption struct {
2727
crypt driver.Crypt
2828
keyVaultClient *Client
2929
keyVaultColl *Collection
30+
closed bool
3031
}
3132

3233
// NewClientEncryption creates a new ClientEncryption instance configured with the given options.
@@ -81,6 +82,10 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options.
8182
func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
8283
db *Database, coll string, createOpts options.Lister[options.CreateCollectionOptions],
8384
kmsProvider string, masterKey interface{}) (*Collection, bson.M, error) {
85+
if ce.closed {
86+
return nil, nil, ErrClientDisconnected
87+
}
88+
8489
if createOpts == nil {
8590
return nil, nil, errors.New("nil CreateCollectionOptions")
8691
}
@@ -141,6 +146,10 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
141146
// AddKeyAltName adds a keyAltName to the keyAltNames array of the key document in the key vault collection with the
142147
// given UUID (BSON binary subtype 0x04). Returns the previous version of the key document.
143148
func (ce *ClientEncryption) AddKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult {
149+
if ce.closed {
150+
return &SingleResult{err: ErrClientDisconnected}
151+
}
152+
144153
filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
145154
keyAltNameDoc := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
146155
update := bsoncore.NewDocumentBuilder().AppendDocument("$addToSet", keyAltNameDoc).Build()
@@ -154,6 +163,10 @@ func (ce *ClientEncryption) CreateDataKey(
154163
kmsProvider string,
155164
opts ...options.Lister[options.DataKeyOptions],
156165
) (bson.Binary, error) {
166+
if ce.closed {
167+
return bson.Binary{}, ErrClientDisconnected
168+
}
169+
157170
args, err := mongoutil.NewOptions[options.DataKeyOptions](opts...)
158171
if err != nil {
159172
return bson.Binary{}, fmt.Errorf("failed to construct options from builder: %w", err)
@@ -238,6 +251,9 @@ func (ce *ClientEncryption) Encrypt(
238251
val bson.RawValue,
239252
opts ...options.Lister[options.EncryptOptions],
240253
) (bson.Binary, error) {
254+
if ce.closed {
255+
return bson.Binary{}, ErrClientDisconnected
256+
}
241257

242258
transformed := transformExplicitEncryptionOptions(opts...)
243259
subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: bsoncore.Type(val.Type), Data: val.Value}, transformed)
@@ -257,6 +273,10 @@ func (ce *ClientEncryption) Encrypt(
257273
// $gt may also be $gte. $lt may also be $lte.
258274
// Only supported for queryType "range"
259275
func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...options.Lister[options.EncryptOptions]) error {
276+
if ce.closed {
277+
return ErrClientDisconnected
278+
}
279+
260280
transformed := transformExplicitEncryptionOptions(opts...)
261281

262282
exprDoc, err := marshal(expr, nil, nil)
@@ -282,6 +302,10 @@ func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interfac
282302

283303
// Decrypt decrypts an encrypted value (BSON binary of subtype 6) and returns the original BSON value.
284304
func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.RawValue, error) {
305+
if ce.closed {
306+
return bson.RawValue{}, ErrClientDisconnected
307+
}
308+
285309
decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data)
286310
if err != nil {
287311
return bson.RawValue{}, err
@@ -293,39 +317,67 @@ func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.
293317
// Close cleans up any resources associated with the ClientEncryption instance. This includes disconnecting the
294318
// key-vault Client instance.
295319
func (ce *ClientEncryption) Close(ctx context.Context) error {
320+
if ce.closed {
321+
return ErrClientDisconnected
322+
}
323+
296324
ce.crypt.Close()
297-
return ce.keyVaultClient.Disconnect(ctx)
325+
err := ce.keyVaultClient.Disconnect(ctx)
326+
if err == nil {
327+
ce.closed = true
328+
}
329+
return err
298330
}
299331

300332
// DeleteKey removes the key document with the given UUID (BSON binary subtype 0x04) from the key vault collection.
301333
// Returns the result of the internal deleteOne() operation on the key vault collection.
302334
func (ce *ClientEncryption) DeleteKey(ctx context.Context, id bson.Binary) (*DeleteResult, error) {
335+
if ce.closed {
336+
return nil, ErrClientDisconnected
337+
}
338+
303339
filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
304340
return ce.keyVaultColl.DeleteOne(ctx, filter)
305341
}
306342

307343
// GetKeyByAltName returns a key document in the key vault collection with the given keyAltName.
308344
func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName string) *SingleResult {
345+
if ce.closed {
346+
return &SingleResult{err: ErrClientDisconnected}
347+
}
348+
309349
filter := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
310350
return ce.keyVaultColl.FindOne(ctx, filter)
311351
}
312352

313353
// GetKey finds a single key document with the given UUID (BSON binary subtype 0x04). Returns the result of the
314354
// internal find() operation on the key vault collection.
315355
func (ce *ClientEncryption) GetKey(ctx context.Context, id bson.Binary) *SingleResult {
356+
if ce.closed {
357+
return &SingleResult{err: ErrClientDisconnected}
358+
}
359+
316360
filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
317361
return ce.keyVaultColl.FindOne(ctx, filter)
318362
}
319363

320364
// GetKeys finds all documents in the key vault collection. Returns the result of the internal find() operation on the
321365
// key vault collection.
322366
func (ce *ClientEncryption) GetKeys(ctx context.Context) (*Cursor, error) {
367+
if ce.closed {
368+
return nil, ErrClientDisconnected
369+
}
370+
323371
return ce.keyVaultColl.Find(ctx, bson.D{})
324372
}
325373

326374
// RemoveKeyAltName removes a keyAltName from the keyAltNames array of the key document in the key vault collection with
327375
// the given UUID (BSON binary subtype 0x04). Returns the previous version of the key document.
328376
func (ce *ClientEncryption) RemoveKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult {
377+
if ce.closed {
378+
return &SingleResult{err: ErrClientDisconnected}
379+
}
380+
329381
filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
330382
update := bson.A{bson.D{{"$set", bson.D{{"keyAltNames", bson.D{{"$cond", bson.A{bson.D{{"$eq",
331383
bson.A{"$keyAltNames", bson.A{keyAltName}}}}, "$$REMOVE", bson.D{{"$filter",
@@ -396,6 +448,10 @@ func (ce *ClientEncryption) RewrapManyDataKey(
396448
) (*RewrapManyDataKeyResult, error) {
397449
// libmongocrypt versions 1.5.0 and 1.5.1 have a severe bug in RewrapManyDataKey.
398450
// Check if the version string starts with 1.5.0 or 1.5.1. This accounts for pre-release versions, like 1.5.0-rc0.
451+
if ce.closed {
452+
return nil, ErrClientDisconnected
453+
}
454+
399455
libmongocryptVersion := mongocrypt.Version()
400456
if strings.HasPrefix(libmongocryptVersion, "1.5.0") || strings.HasPrefix(libmongocryptVersion, "1.5.1") {
401457
return nil, fmt.Errorf("RewrapManyDataKey requires libmongocrypt 1.5.2 or newer. Detected version: %v", libmongocryptVersion)

mongo/client_encryption_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright (C) MongoDB, Inc. 2024-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
//go:build cse
8+
// +build cse
9+
10+
package mongo
11+
12+
import (
13+
"context"
14+
"testing"
15+
16+
"go.mongodb.org/mongo-driver/v2/bson"
17+
"go.mongodb.org/mongo-driver/v2/internal/assert"
18+
"go.mongodb.org/mongo-driver/v2/mongo/options"
19+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
20+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
21+
)
22+
23+
func TestClientEncryption_ErrClientDisconnected(t *testing.T) {
24+
t.Parallel()
25+
26+
client, _ := Connect(options.Client().ApplyURI("mongodb://test"))
27+
crypt := driver.NewCrypt(&driver.CryptOptions{MongoCrypt: &mongocrypt.MongoCrypt{}})
28+
29+
ce := &ClientEncryption{keyVaultClient: client, crypt: crypt}
30+
_ = ce.Close(context.Background())
31+
32+
t.Run("CreateEncryptedCollection", func(t *testing.T) {
33+
t.Parallel()
34+
_, _, err := ce.CreateEncryptedCollection(context.Background(), nil, "", options.CreateCollection(), "", nil)
35+
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
36+
assert.ErrorIs(t, err, ErrClientDisconnected)
37+
})
38+
t.Run("AddKeyAltName", func(t *testing.T) {
39+
t.Parallel()
40+
err := ce.AddKeyAltName(context.Background(), bson.Binary{}, "").err
41+
assert.ErrorIs(t, err, ErrClientDisconnected)
42+
})
43+
t.Run("CreateDataKey", func(t *testing.T) {
44+
t.Parallel()
45+
_, err := ce.CreateDataKey(context.Background(), "", options.DataKey())
46+
assert.ErrorIs(t, err, ErrClientDisconnected)
47+
})
48+
t.Run("Encrypt", func(t *testing.T) {
49+
t.Parallel()
50+
_, err := ce.Encrypt(context.Background(), bson.RawValue{}, options.Encrypt())
51+
assert.ErrorIs(t, err, ErrClientDisconnected)
52+
})
53+
t.Run("EncryptExpression", func(t *testing.T) {
54+
t.Parallel()
55+
err := ce.EncryptExpression(context.Background(), nil, nil, options.Encrypt())
56+
assert.ErrorIs(t, err, ErrClientDisconnected)
57+
})
58+
t.Run("Decrypt", func(t *testing.T) {
59+
t.Parallel()
60+
_, err := ce.Decrypt(context.Background(), bson.Binary{})
61+
assert.ErrorIs(t, err, ErrClientDisconnected)
62+
})
63+
t.Run("Close", func(t *testing.T) {
64+
t.Parallel()
65+
err := ce.Close(context.Background())
66+
assert.ErrorIs(t, err, ErrClientDisconnected)
67+
})
68+
t.Run("DeleteKey", func(t *testing.T) {
69+
t.Parallel()
70+
_, err := ce.DeleteKey(context.Background(), bson.Binary{})
71+
assert.ErrorIs(t, err, ErrClientDisconnected)
72+
})
73+
t.Run("GetKeyByAltName", func(t *testing.T) {
74+
t.Parallel()
75+
err := ce.GetKeyByAltName(context.Background(), "").err
76+
assert.ErrorIs(t, err, ErrClientDisconnected)
77+
})
78+
t.Run("GetKey", func(t *testing.T) {
79+
t.Parallel()
80+
err := ce.GetKey(context.Background(), bson.Binary{}).err
81+
assert.ErrorIs(t, err, ErrClientDisconnected)
82+
})
83+
t.Run("GetKeys", func(t *testing.T) {
84+
t.Parallel()
85+
_, err := ce.GetKeys(context.Background())
86+
assert.ErrorIs(t, err, ErrClientDisconnected)
87+
})
88+
t.Run("RemoveKeyAltName", func(t *testing.T) {
89+
t.Parallel()
90+
err := ce.RemoveKeyAltName(context.Background(), bson.Binary{}, "").err
91+
assert.ErrorIs(t, err, ErrClientDisconnected)
92+
})
93+
t.Run("RewrapManyDataKey", func(t *testing.T) {
94+
t.Parallel()
95+
_, err := ce.RewrapManyDataKey(context.Background(), nil, options.RewrapManyDataKey())
96+
assert.ErrorIs(t, err, ErrClientDisconnected)
97+
})
98+
}

0 commit comments

Comments
 (0)