Skip to content
58 changes: 57 additions & 1 deletion mongo/client_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type ClientEncryption struct {
crypt driver.Crypt
keyVaultClient *Client
keyVaultColl *Collection
closed bool
}

// NewClientEncryption creates a new ClientEncryption instance configured with the given options.
Expand Down Expand Up @@ -81,6 +82,10 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options.
func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
db *Database, coll string, createOpts options.Lister[options.CreateCollectionOptions],
kmsProvider string, masterKey interface{}) (*Collection, bson.M, error) {
if ce.closed {
return nil, nil, ErrClientDisconnected
}

if createOpts == nil {
return nil, nil, errors.New("nil CreateCollectionOptions")
}
Expand Down Expand Up @@ -141,6 +146,10 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
// AddKeyAltName adds a keyAltName to the keyAltNames array of the key document in the key vault collection with the
// given UUID (BSON binary subtype 0x04). Returns the previous version of the key document.
func (ce *ClientEncryption) AddKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
keyAltNameDoc := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
update := bsoncore.NewDocumentBuilder().AppendDocument("$addToSet", keyAltNameDoc).Build()
Expand All @@ -154,6 +163,10 @@ func (ce *ClientEncryption) CreateDataKey(
kmsProvider string,
opts ...options.Lister[options.DataKeyOptions],
) (bson.Binary, error) {
if ce.closed {
return bson.Binary{}, ErrClientDisconnected
}

args, err := mongoutil.NewOptions[options.DataKeyOptions](opts...)
if err != nil {
return bson.Binary{}, fmt.Errorf("failed to construct options from builder: %w", err)
Expand Down Expand Up @@ -238,6 +251,9 @@ func (ce *ClientEncryption) Encrypt(
val bson.RawValue,
opts ...options.Lister[options.EncryptOptions],
) (bson.Binary, error) {
if ce.closed {
return bson.Binary{}, ErrClientDisconnected
}

transformed := transformExplicitEncryptionOptions(opts...)
subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: bsoncore.Type(val.Type), Data: val.Value}, transformed)
Expand All @@ -257,6 +273,10 @@ func (ce *ClientEncryption) Encrypt(
// $gt may also be $gte. $lt may also be $lte.
// Only supported for queryType "range"
func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...options.Lister[options.EncryptOptions]) error {
if ce.closed {
return ErrClientDisconnected
}

transformed := transformExplicitEncryptionOptions(opts...)

exprDoc, err := marshal(expr, nil, nil)
Expand All @@ -282,6 +302,10 @@ func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interfac

// Decrypt decrypts an encrypted value (BSON binary of subtype 6) and returns the original BSON value.
func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.RawValue, error) {
if ce.closed {
return bson.RawValue{}, ErrClientDisconnected
}

decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data)
if err != nil {
return bson.RawValue{}, err
Expand All @@ -293,39 +317,67 @@ func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.
// Close cleans up any resources associated with the ClientEncryption instance. This includes disconnecting the
// key-vault Client instance.
func (ce *ClientEncryption) Close(ctx context.Context) error {
if ce.closed {
return ErrClientDisconnected
}

ce.crypt.Close()
return ce.keyVaultClient.Disconnect(ctx)
err := ce.keyVaultClient.Disconnect(ctx)
if err == nil {
ce.closed = true
}
return err
}

// DeleteKey removes the key document with the given UUID (BSON binary subtype 0x04) from the key vault collection.
// Returns the result of the internal deleteOne() operation on the key vault collection.
func (ce *ClientEncryption) DeleteKey(ctx context.Context, id bson.Binary) (*DeleteResult, error) {
if ce.closed {
return nil, ErrClientDisconnected
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
return ce.keyVaultColl.DeleteOne(ctx, filter)
}

// GetKeyByAltName returns a key document in the key vault collection with the given keyAltName.
func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName string) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
return ce.keyVaultColl.FindOne(ctx, filter)
}

// GetKey finds a single key document with the given UUID (BSON binary subtype 0x04). Returns the result of the
// internal find() operation on the key vault collection.
func (ce *ClientEncryption) GetKey(ctx context.Context, id bson.Binary) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
return ce.keyVaultColl.FindOne(ctx, filter)
}

// GetKeys finds all documents in the key vault collection. Returns the result of the internal find() operation on the
// key vault collection.
func (ce *ClientEncryption) GetKeys(ctx context.Context) (*Cursor, error) {
if ce.closed {
return nil, ErrClientDisconnected
}

return ce.keyVaultColl.Find(ctx, bson.D{})
}

// RemoveKeyAltName removes a keyAltName from the keyAltNames array of the key document in the key vault collection with
// the given UUID (BSON binary subtype 0x04). Returns the previous version of the key document.
func (ce *ClientEncryption) RemoveKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
update := bson.A{bson.D{{"$set", bson.D{{"keyAltNames", bson.D{{"$cond", bson.A{bson.D{{"$eq",
bson.A{"$keyAltNames", bson.A{keyAltName}}}}, "$$REMOVE", bson.D{{"$filter",
Expand Down Expand Up @@ -396,6 +448,10 @@ func (ce *ClientEncryption) RewrapManyDataKey(
) (*RewrapManyDataKeyResult, error) {
// libmongocrypt versions 1.5.0 and 1.5.1 have a severe bug in RewrapManyDataKey.
// Check if the version string starts with 1.5.0 or 1.5.1. This accounts for pre-release versions, like 1.5.0-rc0.
if ce.closed {
return nil, ErrClientDisconnected
}

libmongocryptVersion := mongocrypt.Version()
if strings.HasPrefix(libmongocryptVersion, "1.5.0") || strings.HasPrefix(libmongocryptVersion, "1.5.1") {
return nil, fmt.Errorf("RewrapManyDataKey requires libmongocrypt 1.5.2 or newer. Detected version: %v", libmongocryptVersion)
Expand Down
93 changes: 93 additions & 0 deletions mongo/client_encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package mongo

import (
"context"
"testing"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
)

func TestClientEncryption_ErrClientDisconnected(t *testing.T) {
client, _ := Connect(options.Client().ApplyURI("mongodb://test"))
crypt := driver.NewCrypt(&driver.CryptOptions{MongoCrypt: &mongocrypt.MongoCrypt{}})

ce := &ClientEncryption{keyVaultClient: client, crypt: crypt}
_ = ce.Close(context.Background())

t.Run("CreateEncryptedCollection", func(t *testing.T) {
t.Parallel()
_, _, err := ce.CreateEncryptedCollection(context.Background(), nil, "", options.CreateCollection(), "", nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("AddKeyAltName", func(t *testing.T) {
t.Parallel()
err := ce.AddKeyAltName(context.Background(), bson.Binary{}, "").err
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("CreateDataKey", func(t *testing.T) {
t.Parallel()
_, err := ce.CreateDataKey(context.Background(), "", options.DataKey())
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("Encrypt", func(t *testing.T) {
t.Parallel()
_, err := ce.Encrypt(context.Background(), bson.RawValue{}, options.Encrypt())
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("EncryptExpression", func(t *testing.T) {
t.Parallel()
err := ce.EncryptExpression(context.Background(), nil, nil, options.Encrypt())
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("Decrypt", func(t *testing.T) {
t.Parallel()
_, err := ce.Decrypt(context.Background(), bson.Binary{})
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("Close", func(t *testing.T) {
t.Parallel()
err := ce.Close(context.Background())
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("DeleteKey", func(t *testing.T) {
t.Parallel()
_, err := ce.DeleteKey(context.Background(), bson.Binary{})
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("GetKeyByAltName", func(t *testing.T) {
t.Parallel()
err := ce.GetKeyByAltName(context.Background(), "").err
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("GetKey", func(t *testing.T) {
t.Parallel()
err := ce.GetKey(context.Background(), bson.Binary{}).err
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("GetKeys", func(t *testing.T) {
t.Parallel()
_, err := ce.GetKeys(context.Background())
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("RemoveKeyAltName", func(t *testing.T) {
t.Parallel()
err := ce.RemoveKeyAltName(context.Background(), bson.Binary{}, "").err
assert.ErrorIs(t, err, ErrClientDisconnected)
})
t.Run("RewrapManyDataKey", func(t *testing.T) {
t.Parallel()
_, err := ce.RewrapManyDataKey(context.Background(), nil, options.RewrapManyDataKey())
assert.ErrorIs(t, err, ErrClientDisconnected)
})
}
Loading