diff --git a/mongo/client_encryption.go b/mongo/client_encryption.go index ee6d62939e..07c18529a8 100644 --- a/mongo/client_encryption.go +++ b/mongo/client_encryption.go @@ -27,6 +27,7 @@ type ClientEncryption struct { crypt driver.Crypt keyVaultClient *Client keyVaultColl *Collection + closed bool } // NewClientEncryption creates a new ClientEncryption instance configured with the given options. @@ -81,6 +82,10 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options. func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context, db *Database, coll string, createOpts options.Lister[options.CreateCollectionOptions], kmsProvider string, masterKey interface{}) (*Collection, bson.M, error) { + if ce.closed { + return nil, nil, ErrClientDisconnected + } + if createOpts == nil { return nil, nil, errors.New("nil CreateCollectionOptions") } @@ -141,6 +146,10 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context, // AddKeyAltName adds a keyAltName to the keyAltNames array of the key document in the key vault collection with the // given UUID (BSON binary subtype 0x04). Returns the previous version of the key document. func (ce *ClientEncryption) AddKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult { + if ce.closed { + return &SingleResult{err: ErrClientDisconnected} + } + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() keyAltNameDoc := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build() update := bsoncore.NewDocumentBuilder().AppendDocument("$addToSet", keyAltNameDoc).Build() @@ -154,6 +163,10 @@ func (ce *ClientEncryption) CreateDataKey( kmsProvider string, opts ...options.Lister[options.DataKeyOptions], ) (bson.Binary, error) { + if ce.closed { + return bson.Binary{}, ErrClientDisconnected + } + args, err := mongoutil.NewOptions[options.DataKeyOptions](opts...) if err != nil { return bson.Binary{}, fmt.Errorf("failed to construct options from builder: %w", err) @@ -238,6 +251,9 @@ func (ce *ClientEncryption) Encrypt( val bson.RawValue, opts ...options.Lister[options.EncryptOptions], ) (bson.Binary, error) { + if ce.closed { + return bson.Binary{}, ErrClientDisconnected + } transformed := transformExplicitEncryptionOptions(opts...) subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: bsoncore.Type(val.Type), Data: val.Value}, transformed) @@ -257,6 +273,10 @@ func (ce *ClientEncryption) Encrypt( // $gt may also be $gte. $lt may also be $lte. // Only supported for queryType "range" func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...options.Lister[options.EncryptOptions]) error { + if ce.closed { + return ErrClientDisconnected + } + transformed := transformExplicitEncryptionOptions(opts...) exprDoc, err := marshal(expr, nil, nil) @@ -282,6 +302,10 @@ func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interfac // Decrypt decrypts an encrypted value (BSON binary of subtype 6) and returns the original BSON value. func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.RawValue, error) { + if ce.closed { + return bson.RawValue{}, ErrClientDisconnected + } + decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data) if err != nil { return bson.RawValue{}, err @@ -293,19 +317,35 @@ func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson. // Close cleans up any resources associated with the ClientEncryption instance. This includes disconnecting the // key-vault Client instance. func (ce *ClientEncryption) Close(ctx context.Context) error { + if ce.closed { + return ErrClientDisconnected + } + ce.crypt.Close() - return ce.keyVaultClient.Disconnect(ctx) + err := ce.keyVaultClient.Disconnect(ctx) + if err == nil { + ce.closed = true + } + return err } // DeleteKey removes the key document with the given UUID (BSON binary subtype 0x04) from the key vault collection. // Returns the result of the internal deleteOne() operation on the key vault collection. func (ce *ClientEncryption) DeleteKey(ctx context.Context, id bson.Binary) (*DeleteResult, error) { + if ce.closed { + return nil, ErrClientDisconnected + } + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() return ce.keyVaultColl.DeleteOne(ctx, filter) } // GetKeyByAltName returns a key document in the key vault collection with the given keyAltName. func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName string) *SingleResult { + if ce.closed { + return &SingleResult{err: ErrClientDisconnected} + } + filter := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build() return ce.keyVaultColl.FindOne(ctx, filter) } @@ -313,6 +353,10 @@ func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName stri // GetKey finds a single key document with the given UUID (BSON binary subtype 0x04). Returns the result of the // internal find() operation on the key vault collection. func (ce *ClientEncryption) GetKey(ctx context.Context, id bson.Binary) *SingleResult { + if ce.closed { + return &SingleResult{err: ErrClientDisconnected} + } + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() return ce.keyVaultColl.FindOne(ctx, filter) } @@ -320,12 +364,20 @@ func (ce *ClientEncryption) GetKey(ctx context.Context, id bson.Binary) *SingleR // GetKeys finds all documents in the key vault collection. Returns the result of the internal find() operation on the // key vault collection. func (ce *ClientEncryption) GetKeys(ctx context.Context) (*Cursor, error) { + if ce.closed { + return nil, ErrClientDisconnected + } + return ce.keyVaultColl.Find(ctx, bson.D{}) } // RemoveKeyAltName removes a keyAltName from the keyAltNames array of the key document in the key vault collection with // the given UUID (BSON binary subtype 0x04). Returns the previous version of the key document. func (ce *ClientEncryption) RemoveKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult { + if ce.closed { + return &SingleResult{err: ErrClientDisconnected} + } + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() update := bson.A{bson.D{{"$set", bson.D{{"keyAltNames", bson.D{{"$cond", bson.A{bson.D{{"$eq", bson.A{"$keyAltNames", bson.A{keyAltName}}}}, "$$REMOVE", bson.D{{"$filter", @@ -396,6 +448,10 @@ func (ce *ClientEncryption) RewrapManyDataKey( ) (*RewrapManyDataKeyResult, error) { // libmongocrypt versions 1.5.0 and 1.5.1 have a severe bug in RewrapManyDataKey. // Check if the version string starts with 1.5.0 or 1.5.1. This accounts for pre-release versions, like 1.5.0-rc0. + if ce.closed { + return nil, ErrClientDisconnected + } + libmongocryptVersion := mongocrypt.Version() if strings.HasPrefix(libmongocryptVersion, "1.5.0") || strings.HasPrefix(libmongocryptVersion, "1.5.1") { return nil, fmt.Errorf("RewrapManyDataKey requires libmongocrypt 1.5.2 or newer. Detected version: %v", libmongocryptVersion) diff --git a/mongo/client_encryption_test.go b/mongo/client_encryption_test.go new file mode 100644 index 0000000000..35acedb5b2 --- /dev/null +++ b/mongo/client_encryption_test.go @@ -0,0 +1,98 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +//go:build cse +// +build cse + +package mongo + +import ( + "context" + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt" +) + +func TestClientEncryption_ErrClientDisconnected(t *testing.T) { + t.Parallel() + + client, _ := Connect(options.Client().ApplyURI("mongodb://test")) + crypt := driver.NewCrypt(&driver.CryptOptions{MongoCrypt: &mongocrypt.MongoCrypt{}}) + + ce := &ClientEncryption{keyVaultClient: client, crypt: crypt} + _ = ce.Close(context.Background()) + + t.Run("CreateEncryptedCollection", func(t *testing.T) { + t.Parallel() + _, _, err := ce.CreateEncryptedCollection(context.Background(), nil, "", options.CreateCollection(), "", nil) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("AddKeyAltName", func(t *testing.T) { + t.Parallel() + err := ce.AddKeyAltName(context.Background(), bson.Binary{}, "").err + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("CreateDataKey", func(t *testing.T) { + t.Parallel() + _, err := ce.CreateDataKey(context.Background(), "", options.DataKey()) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("Encrypt", func(t *testing.T) { + t.Parallel() + _, err := ce.Encrypt(context.Background(), bson.RawValue{}, options.Encrypt()) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("EncryptExpression", func(t *testing.T) { + t.Parallel() + err := ce.EncryptExpression(context.Background(), nil, nil, options.Encrypt()) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("Decrypt", func(t *testing.T) { + t.Parallel() + _, err := ce.Decrypt(context.Background(), bson.Binary{}) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("Close", func(t *testing.T) { + t.Parallel() + err := ce.Close(context.Background()) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("DeleteKey", func(t *testing.T) { + t.Parallel() + _, err := ce.DeleteKey(context.Background(), bson.Binary{}) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("GetKeyByAltName", func(t *testing.T) { + t.Parallel() + err := ce.GetKeyByAltName(context.Background(), "").err + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("GetKey", func(t *testing.T) { + t.Parallel() + err := ce.GetKey(context.Background(), bson.Binary{}).err + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("GetKeys", func(t *testing.T) { + t.Parallel() + _, err := ce.GetKeys(context.Background()) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("RemoveKeyAltName", func(t *testing.T) { + t.Parallel() + err := ce.RemoveKeyAltName(context.Background(), bson.Binary{}, "").err + assert.ErrorIs(t, err, ErrClientDisconnected) + }) + t.Run("RewrapManyDataKey", func(t *testing.T) { + t.Parallel() + _, err := ce.RewrapManyDataKey(context.Background(), nil, options.RewrapManyDataKey()) + assert.ErrorIs(t, err, ErrClientDisconnected) + }) +}