Skip to content
59 changes: 58 additions & 1 deletion mongo/client_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type ClientEncryption struct {
crypt driver.Crypt
keyVaultClient *Client
keyVaultColl *Collection
closed bool
}

// NewClientEncryption creates a new ClientEncryption instance configured with the given options.
Expand All @@ -37,6 +38,7 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options.

ce := &ClientEncryption{
keyVaultClient: keyVaultClient,
closed: false,
}
cea, err := mongoutil.NewOptions(opts...)
if err != nil {
Expand Down Expand Up @@ -85,6 +87,10 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
return nil, nil, errors.New("nil CreateCollectionOptions")
}

if ce.closed {
return nil, nil, ErrClientDisconnected
}

createArgs, err := mongoutil.NewOptions[options.CreateCollectionOptions](createOpts)
if err != nil {
return nil, nil, fmt.Errorf("failed to construct options from builder: %w", err)
Expand Down Expand Up @@ -141,6 +147,10 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
// AddKeyAltName adds a keyAltName to the keyAltNames array of the key document in the key vault collection with the
// given UUID (BSON binary subtype 0x04). Returns the previous version of the key document.
func (ce *ClientEncryption) AddKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
keyAltNameDoc := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
update := bsoncore.NewDocumentBuilder().AppendDocument("$addToSet", keyAltNameDoc).Build()
Expand All @@ -154,6 +164,10 @@ func (ce *ClientEncryption) CreateDataKey(
kmsProvider string,
opts ...options.Lister[options.DataKeyOptions],
) (bson.Binary, error) {
if ce.closed {
return bson.Binary{}, ErrClientDisconnected
}

args, err := mongoutil.NewOptions[options.DataKeyOptions](opts...)
if err != nil {
return bson.Binary{}, fmt.Errorf("failed to construct options from builder: %w", err)
Expand Down Expand Up @@ -238,6 +252,9 @@ func (ce *ClientEncryption) Encrypt(
val bson.RawValue,
opts ...options.Lister[options.EncryptOptions],
) (bson.Binary, error) {
if ce.closed {
return bson.Binary{}, ErrClientDisconnected
}

transformed := transformExplicitEncryptionOptions(opts...)
subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: bsoncore.Type(val.Type), Data: val.Value}, transformed)
Expand All @@ -257,6 +274,10 @@ func (ce *ClientEncryption) Encrypt(
// $gt may also be $gte. $lt may also be $lte.
// Only supported for queryType "range"
func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...options.Lister[options.EncryptOptions]) error {
if ce.closed {
return ErrClientDisconnected
}

transformed := transformExplicitEncryptionOptions(opts...)

exprDoc, err := marshal(expr, nil, nil)
Expand All @@ -282,6 +303,10 @@ func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interfac

// Decrypt decrypts an encrypted value (BSON binary of subtype 6) and returns the original BSON value.
func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.RawValue, error) {
if ce.closed {
return bson.RawValue{}, ErrClientDisconnected
}

decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data)
if err != nil {
return bson.RawValue{}, err
Expand All @@ -293,39 +318,67 @@ func (ce *ClientEncryption) Decrypt(ctx context.Context, val bson.Binary) (bson.
// Close cleans up any resources associated with the ClientEncryption instance. This includes disconnecting the
// key-vault Client instance.
func (ce *ClientEncryption) Close(ctx context.Context) error {
if ce.closed {
return ErrClientDisconnected
}

ce.crypt.Close()
return ce.keyVaultClient.Disconnect(ctx)
err := ce.keyVaultClient.Disconnect(ctx)
if err == nil {
ce.closed = true
}
return err
}

// DeleteKey removes the key document with the given UUID (BSON binary subtype 0x04) from the key vault collection.
// Returns the result of the internal deleteOne() operation on the key vault collection.
func (ce *ClientEncryption) DeleteKey(ctx context.Context, id bson.Binary) (*DeleteResult, error) {
if ce.closed {
return nil, ErrClientDisconnected
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
return ce.keyVaultColl.DeleteOne(ctx, filter)
}

// GetKeyByAltName returns a key document in the key vault collection with the given keyAltName.
func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName string) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
return ce.keyVaultColl.FindOne(ctx, filter)
}

// GetKey finds a single key document with the given UUID (BSON binary subtype 0x04). Returns the result of the
// internal find() operation on the key vault collection.
func (ce *ClientEncryption) GetKey(ctx context.Context, id bson.Binary) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
return ce.keyVaultColl.FindOne(ctx, filter)
}

// GetKeys finds all documents in the key vault collection. Returns the result of the internal find() operation on the
// key vault collection.
func (ce *ClientEncryption) GetKeys(ctx context.Context) (*Cursor, error) {
if ce.closed {
return nil, ErrClientDisconnected
}

return ce.keyVaultColl.Find(ctx, bson.D{})
}

// RemoveKeyAltName removes a keyAltName from the keyAltNames array of the key document in the key vault collection with
// the given UUID (BSON binary subtype 0x04). Returns the previous version of the key document.
func (ce *ClientEncryption) RemoveKeyAltName(ctx context.Context, id bson.Binary, keyAltName string) *SingleResult {
if ce.closed {
return &SingleResult{err: ErrClientDisconnected}
}

filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
update := bson.A{bson.D{{"$set", bson.D{{"keyAltNames", bson.D{{"$cond", bson.A{bson.D{{"$eq",
bson.A{"$keyAltNames", bson.A{keyAltName}}}}, "$$REMOVE", bson.D{{"$filter",
Expand Down Expand Up @@ -396,6 +449,10 @@ func (ce *ClientEncryption) RewrapManyDataKey(
) (*RewrapManyDataKeyResult, error) {
// libmongocrypt versions 1.5.0 and 1.5.1 have a severe bug in RewrapManyDataKey.
// Check if the version string starts with 1.5.0 or 1.5.1. This accounts for pre-release versions, like 1.5.0-rc0.
if ce.closed {
return nil, ErrClientDisconnected
}

libmongocryptVersion := mongocrypt.Version()
if strings.HasPrefix(libmongocryptVersion, "1.5.0") || strings.HasPrefix(libmongocryptVersion, "1.5.1") {
return nil, fmt.Errorf("RewrapManyDataKey requires libmongocrypt 1.5.2 or newer. Detected version: %v", libmongocryptVersion)
Expand Down
74 changes: 74 additions & 0 deletions mongo/client_encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package mongo

import (
"context"
"testing"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)

func TestClientEncryption(t *testing.T) {
ce := &ClientEncryption{closed: true}

t.Run("CreateEncryptedCollection", func(t *testing.T) {
// used options.CreateCollection() to avoid catching the nil CreateCollectionOptions error
_, _, err := ce.CreateEncryptedCollection(context.TODO(), nil, "", options.CreateCollection(), "", nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("AddKeyAltName", func(t *testing.T) {
err := ce.AddKeyAltName(context.TODO(), bson.Binary{}, "").err
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("CreateDataKey", func(t *testing.T) {
_, err := ce.CreateDataKey(context.TODO(), "", options.DataKey())
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("Encrypt", func(t *testing.T) {
_, err := ce.Encrypt(context.TODO(), bson.RawValue{}, options.Encrypt())
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("EncryptExpression", func(t *testing.T) {
err := ce.EncryptExpression(context.TODO(), nil, nil, options.Encrypt())
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("Decrypt", func(t *testing.T) {
_, err := ce.Decrypt(context.TODO(), bson.Binary{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("Close", func(t *testing.T) {
err := ce.Close(context.TODO())
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("DeleteKey", func(t *testing.T) {
_, err := ce.DeleteKey(context.TODO(), bson.Binary{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("GetKeyByAltName", func(t *testing.T) {
err := ce.GetKeyByAltName(context.TODO(), "").err
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("GetKey", func(t *testing.T) {
err := ce.GetKey(context.TODO(), bson.Binary{}).err
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("GetKeys", func(t *testing.T) {
_, err := ce.GetKeys(context.TODO())
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("RemoveKeyAltName", func(t *testing.T) {
err := ce.RemoveKeyAltName(context.TODO(), bson.Binary{}, "").err
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("RewrapManyDataKey", func(t *testing.T) {
_, err := ce.RewrapManyDataKey(context.TODO(), nil, options.RewrapManyDataKey())
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
}
Loading