Skip to content

Commit fb86f65

Browse files
committed
PYTHON-1950 Restrict key_id to Binary subtype 4
1 parent 69ec553 commit fb86f65

File tree

2 files changed

+63
-36
lines changed

2 files changed

+63
-36
lines changed

pymongo/encryption.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Client side encryption."""
1616

17+
import contextlib
1718
import functools
1819
import subprocess
1920
import uuid
@@ -32,7 +33,9 @@
3233

3334
from bson import _bson_to_dict, _dict_to_bson, decode, encode
3435
from bson.codec_options import CodecOptions
35-
from bson.binary import STANDARD, Binary
36+
from bson.binary import (Binary,
37+
STANDARD,
38+
UUID_SUBTYPE)
3639
from bson.errors import BSONError
3740
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
3841
RawBSONDocument,
@@ -58,20 +61,30 @@
5861
uuid_representation=STANDARD)
5962

6063

64+
@contextlib.contextmanager
65+
def _wrap_encryption_errors_ctx():
66+
"""Context manager to wrap encryption related errors."""
67+
try:
68+
yield
69+
except BSONError:
70+
# BSON encoding/decoding errors are unrelated to encryption so
71+
# we should propagate them unchanged.
72+
raise
73+
except Exception as exc:
74+
raise EncryptionError(exc)
75+
76+
6177
def _wrap_encryption_errors(encryption_func=None):
62-
"""Decorator to wrap encryption related errors with EncryptionError."""
63-
@functools.wraps(encryption_func)
64-
def wrap_encryption_errors(*args, **kwargs):
65-
try:
66-
return encryption_func(*args, **kwargs)
67-
except BSONError:
68-
# BSON encoding/decoding errors are unrelated to encryption so
69-
# we should propagate them unchanged.
70-
raise
71-
except Exception as exc:
72-
raise EncryptionError(exc)
78+
"""Decorator or context manager to wrap encryption related errors."""
79+
if encryption_func:
80+
@functools.wraps(encryption_func)
81+
def wrap_encryption_errors(*args, **kwargs):
82+
with _wrap_encryption_errors_ctx():
83+
return encryption_func(*args, **kwargs)
7384

74-
return wrap_encryption_errors
85+
return wrap_encryption_errors
86+
else:
87+
return _wrap_encryption_errors_ctx()
7588

7689

7790
class _EncryptionIO(MongoCryptCallback):
@@ -190,8 +203,11 @@ def insert_data_key(self, data_key):
190203
"""
191204
# insert does not return the inserted _id when given a RawBSONDocument.
192205
doc = _bson_to_dict(data_key, _DATA_KEY_OPTS)
206+
if not isinstance(doc.get('_id'), uuid.UUID):
207+
raise TypeError(
208+
'data_key _id must be a bson.binary.Binary with subtype 4')
193209
res = self.key_vault_coll.insert_one(doc)
194-
return res.inserted_id
210+
return Binary(res.inserted_id.bytes, subtype=UUID_SUBTYPE)
195211

196212
def bson_encode(self, doc):
197213
"""Encode a document to BSON.
@@ -406,7 +422,6 @@ def create_data_key(self, kms_provider, master_key=None,
406422
return self._encryption.create_data_key(
407423
kms_provider, master_key=master_key, key_alt_names=key_alt_names)
408424

409-
@_wrap_encryption_errors
410425
def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
411426
"""Encrypt a BSON value with a given key and algorithm.
412427
@@ -417,28 +432,25 @@ def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
417432
- `value`: The BSON value to encrypt.
418433
- `algorithm` (string): The encryption algorithm to use. See
419434
:class:`Algorithm` for some valid options.
420-
- `key_id`: Identifies a data key by ``_id`` which must be a UUID
421-
or a :class:`~bson.binary.Binary` with subtype 4.
435+
- `key_id`: Identifies a data key by ``_id`` which must be a
436+
:class:`~bson.binary.Binary` with subtype 4 (
437+
:attr:`~bson.binary.UUID_SUBTYPE`).
422438
- `key_alt_name`: Identifies a key vault document by 'keyAltName'.
423439
424440
:Returns:
425441
The encrypted value, a :class:`~bson.binary.Binary` with subtype 6.
426442
"""
427-
doc = encode({'v': value}, codec_options=self._codec_options)
428-
if isinstance(key_id, uuid.UUID):
429-
raw_key_id = key_id.bytes
430-
else:
431-
raw_key_id = key_id
432-
encrypted_doc = self._encryption.encrypt(
433-
doc, algorithm, key_id=raw_key_id, key_alt_name=key_alt_name)
434-
return decode(encrypted_doc)['v']
443+
if (key_id is not None and not (
444+
isinstance(key_id, Binary) and
445+
key_id.subtype == UUID_SUBTYPE)):
446+
raise TypeError(
447+
'key_id must be a bson.binary.Binary with subtype 4')
435448

436-
@_wrap_encryption_errors
437-
def _decrypt(self, value):
438-
"""Internal decrypt helper."""
439-
doc = encode({'v': value})
440-
decrypted_doc = self._encryption.decrypt(doc)
441-
return decode(decrypted_doc, codec_options=self._codec_options)['v']
449+
doc = encode({'v': value}, codec_options=self._codec_options)
450+
with _wrap_encryption_errors_ctx():
451+
encrypted_doc = self._encryption.encrypt(
452+
doc, algorithm, key_id=key_id, key_alt_name=key_alt_name)
453+
return decode(encrypted_doc)['v']
442454

443455
def decrypt(self, value):
444456
"""Decrypt an encrypted value.
@@ -454,7 +466,11 @@ def decrypt(self, value):
454466
raise TypeError(
455467
'value to decrypt must be a bson.binary.Binary with subtype 6')
456468

457-
return self._decrypt(value)
469+
with _wrap_encryption_errors_ctx():
470+
doc = encode({'v': value})
471+
decrypted_doc = self._encryption.decrypt(doc)
472+
return decode(decrypted_doc,
473+
codec_options=self._codec_options)['v']
458474

459475
def close(self):
460476
"""Release resources."""

test/test_encryption.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def assertEncrypted(self, val):
141141
self.assertIsInstance(val, Binary)
142142
self.assertEqual(val.subtype, 6)
143143

144+
def assertBinaryUUID(self, val):
145+
self.assertIsInstance(val, Binary)
146+
self.assertEqual(val.subtype, UUID_SUBTYPE)
147+
144148

145149
# Location of JSON test files.
146150
BASE = os.path.join(
@@ -266,13 +270,13 @@ def test_encrypt_decrypt(self):
266270
# Create the encrypted field's data key.
267271
key_id = client_encryption.create_data_key(
268272
'local', key_alt_names=['name'])
269-
self.assertIsInstance(key_id, uuid.UUID)
273+
self.assertBinaryUUID(key_id)
270274
self.assertTrue(key_vault.find_one({'_id': key_id}))
271275

272276
# Create an unused data key to make sure filtering works.
273277
unused_key_id = client_encryption.create_data_key(
274278
'local', key_alt_names=['unused'])
275-
self.assertIsInstance(unused_key_id, uuid.UUID)
279+
self.assertBinaryUUID(unused_key_id)
276280
self.assertTrue(key_vault.find_one({'_id': unused_key_id}))
277281

278282
doc = {'_id': 0, 'ssn': '000'}
@@ -302,6 +306,13 @@ def test_validation(self):
302306
with self.assertRaisesRegex(TypeError, msg):
303307
client_encryption.decrypt(Binary(b'123'))
304308

309+
msg = 'key_id must be a bson.binary.Binary with subtype 4'
310+
algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
311+
with self.assertRaisesRegex(TypeError, msg):
312+
client_encryption.encrypt('str', algo, key_id=uuid.uuid4())
313+
with self.assertRaisesRegex(TypeError, msg):
314+
client_encryption.encrypt('str', algo, key_id=Binary(b'123'))
315+
305316
def test_bson_errors(self):
306317
client_encryption = ClientEncryption(
307318
KMS_PROVIDERS, 'admin.datakeys', client_context.client, OPTS)
@@ -529,7 +540,7 @@ def test_data_key(self):
529540
# Local create data key.
530541
local_datakey_id = client_encryption.create_data_key(
531542
'local', key_alt_names=['local_altname'])
532-
self.assertIsInstance(local_datakey_id, uuid.UUID)
543+
self.assertBinaryUUID(local_datakey_id)
533544
docs = list(vault.find({'_id': local_datakey_id}))
534545
self.assertEqual(len(docs), 1)
535546
self.assertEqual(docs[0]['masterKey']['provider'], 'local')
@@ -560,7 +571,7 @@ def test_data_key(self):
560571
}
561572
aws_datakey_id = client_encryption.create_data_key(
562573
'aws', master_key=master_key, key_alt_names=['aws_altname'])
563-
self.assertIsInstance(aws_datakey_id, uuid.UUID)
574+
self.assertBinaryUUID(aws_datakey_id)
564575
docs = list(vault.find({'_id': aws_datakey_id}))
565576
self.assertEqual(len(docs), 1)
566577
self.assertEqual(docs[0]['masterKey']['provider'], 'aws')

0 commit comments

Comments
 (0)