Skip to content

Commit e6eecb0

Browse files
committed
PYTHON-1884 Implement auto encryption spec tests
Skip test for symbol type which pymongo converts to string. Fix {} comparison with RawBSONDocument in command events. Add support for $$type assertions. Nicer message in check_events. Support errorContains with empty string. Move custom data files to custom/.
1 parent 743042d commit e6eecb0

File tree

9 files changed

+236
-41
lines changed

9 files changed

+236
-41
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ pymongo.egg-info/
1313
*.so
1414
*.egg
1515
.tox
16+
mongocryptd.pid

pymongo/encryption.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
_inflate_bson)
3131
from bson.son import SON
3232

33-
from pymongo.errors import ServerSelectionTimeoutError
33+
from pymongo.errors import (EncryptionError,
34+
ServerSelectionTimeoutError)
3435
from pymongo.mongo_client import MongoClient
3536
from pymongo.pool import _configured_socket, PoolOptions
3637
from pymongo.ssl_support import get_ssl_context
@@ -210,8 +211,11 @@ def encrypt(self, database, cmd, check_keys, codec_options):
210211
"""
211212
# Workaround for $clusterTime which is incompatible with check_keys.
212213
cluster_time = check_keys and cmd.pop('$clusterTime', None)
213-
encrypted_cmd = self._auto_encrypter.encrypt(
214-
database, _dict_to_bson(cmd, check_keys, codec_options))
214+
encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options)
215+
try:
216+
encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd)
217+
except MongoCryptError as exc:
218+
raise EncryptionError(exc)
215219
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
216220
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
217221
if cluster_time:
@@ -227,7 +231,10 @@ def decrypt(self, response):
227231
:Returns:
228232
The decrypted command response.
229233
"""
230-
return self._auto_encrypter.decrypt(response)
234+
try:
235+
return self._auto_encrypter.decrypt(response)
236+
except MongoCryptError as exc:
237+
raise EncryptionError(exc)
231238

232239
def close(self):
233240
"""Cleanup resources."""

pymongo/errors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,10 @@ class DocumentTooLarge(InvalidDocument):
247247
"""Raised when an encoded document is too large for the connected server.
248248
"""
249249
pass
250+
251+
252+
class EncryptionError(OperationFailure):
253+
"""Raised when encryption or decryption fails.
254+
255+
.. versionadded:: 3.9
256+
"""

test/test_encryption.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
from pymongo.write_concern import WriteConcern
3434

3535
from test import unittest, IntegrationTest, PyMongoTestCase, client_context
36-
from test.utils import wait_until
36+
from test.utils import TestCreator, camel_to_snake_args, wait_until
37+
from test.utils_spec_runner import SpecRunner
3738

3839

3940
if _HAVE_PYMONGOCRYPT:
@@ -129,8 +130,10 @@ def setUpClass(cls):
129130

130131

131132
# Location of JSON test files.
132-
TEST_PATH = os.path.join(
133+
BASE = os.path.join(
133134
os.path.dirname(os.path.realpath(__file__)), 'client-side-encryption')
135+
CUSTOM_PATH = os.path.join(BASE, 'custom')
136+
SPEC_PATH = os.path.join(BASE, 'spec')
134137

135138
OPTS = CodecOptions(uuid_representation=STANDARD)
136139

@@ -139,7 +142,7 @@ def setUpClass(cls):
139142

140143

141144
def read(filename):
142-
with open(os.path.join(TEST_PATH, filename)) as fp:
145+
with open(os.path.join(CUSTOM_PATH, filename)) as fp:
143146
return fp.read()
144147

145148

@@ -231,5 +234,104 @@ def test_auto_encrypt_local_schema_map(self):
231234
self._test_auto_encrypt(opts)
232235

233236

237+
# Spec tests
238+
239+
AWS_CREDS = {
240+
'accessKeyId': os.environ.get('FLE_AWS_KEY', ''),
241+
'secretAccessKey': os.environ.get('FLE_AWS_SECRET', '')
242+
}
243+
244+
245+
class TestSpec(SpecRunner):
246+
247+
@classmethod
248+
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed')
249+
def setUpClass(cls):
250+
super(TestSpec, cls).setUpClass()
251+
252+
def parse_auto_encrypt_opts(self, opts):
253+
"""Parse clientOptions.autoEncryptOpts."""
254+
opts = camel_to_snake_args(opts)
255+
kms_providers = opts['kms_providers']
256+
if 'aws' in kms_providers:
257+
kms_providers['aws'] = AWS_CREDS
258+
if not any(AWS_CREDS.values()):
259+
self.skipTest('AWS environment credentials are not set')
260+
if 'key_vault_namespace' not in opts:
261+
opts['key_vault_namespace'] = 'admin.datakeys'
262+
opts = dict(opts)
263+
return AutoEncryptionOpts(**opts)
264+
265+
def parse_client_options(self, opts):
266+
"""Override clientOptions parsing to support autoEncryptOpts."""
267+
encrypt_opts = opts.pop('autoEncryptOpts')
268+
if encrypt_opts:
269+
opts['auto_encryption_opts'] = self.parse_auto_encrypt_opts(
270+
encrypt_opts)
271+
272+
return super(TestSpec, self).parse_client_options(opts)
273+
274+
def get_object_name(self, op):
275+
"""Default object is collection."""
276+
return op.get('object', 'collection')
277+
278+
def maybe_skip_scenario(self, test):
279+
super(TestSpec, self).maybe_skip_scenario(test)
280+
if 'type=symbol' in test['description'].lower():
281+
raise unittest.SkipTest(
282+
'PyMongo does not support the symbol type')
283+
284+
def setup_scenario(self, scenario_def):
285+
"""Override a test's setup."""
286+
key_vault_data = scenario_def['key_vault_data']
287+
if key_vault_data:
288+
coll = client_context.client.get_database(
289+
'admin',
290+
write_concern=WriteConcern(w='majority'),
291+
codec_options=OPTS)['datakeys']
292+
coll.drop()
293+
coll.insert_many(key_vault_data)
294+
295+
db_name = self.get_scenario_db_name(scenario_def)
296+
coll_name = self.get_scenario_coll_name(scenario_def)
297+
db = client_context.client.get_database(
298+
db_name, write_concern=WriteConcern(w='majority'),
299+
codec_options=OPTS)
300+
coll = db[coll_name]
301+
coll.drop()
302+
json_schema = scenario_def['json_schema']
303+
if json_schema:
304+
db.create_collection(
305+
coll_name,
306+
validator={'$jsonSchema': json_schema}, codec_options=OPTS)
307+
else:
308+
db.create_collection(coll_name)
309+
310+
if scenario_def['data']:
311+
# Load data.
312+
coll.insert_many(scenario_def['data'])
313+
314+
def allowable_errors(self, op):
315+
"""Override expected error classes."""
316+
errors = super(TestSpec, self).allowable_errors(op)
317+
# An updateOne test expects encryption to error when no $ operator
318+
# appears but pymongo raises a client side ValueError in this case.
319+
if op['name'] == 'updateOne':
320+
errors += (ValueError,)
321+
return errors
322+
323+
324+
def create_test(scenario_def, test, name):
325+
@client_context.require_test_commands
326+
def run_scenario(self):
327+
self.run_scenario(scenario_def, test)
328+
329+
return run_scenario
330+
331+
332+
test_creator = TestCreator(create_test, TestSpec, SPEC_PATH)
333+
test_creator.create_tests()
334+
335+
234336
if __name__ == "__main__":
235337
unittest.main()

test/test_retryable_reads.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
sys.path[0:0] = [""]
2121

2222
from pymongo.mongo_client import MongoClient
23+
from pymongo.write_concern import WriteConcern
2324

2425
from test import unittest, client_context, PyMongoTestCase
2526
from test.utils import TestCreator
@@ -67,6 +68,28 @@ def maybe_skip_scenario(self, test):
6768
raise unittest.SkipTest(
6869
'PyMongo does not support %s' % (name,))
6970

71+
def get_scenario_coll_name(self, scenario_def):
72+
"""Override a test's collection name to support GridFS tests."""
73+
if 'bucket_name' in scenario_def:
74+
return scenario_def['bucket_name']
75+
return super(TestSpec, self).get_scenario_coll_name(scenario_def)
76+
77+
def setup_scenario(self, scenario_def):
78+
"""Override a test's setup to support GridFS tests."""
79+
if 'bucket_name' in scenario_def:
80+
db_name = self.get_scenario_db_name(scenario_def)
81+
db = client_context.client.get_database(
82+
db_name, write_concern=WriteConcern(w='majority'))
83+
# Create a bucket for the retryable reads GridFS tests.
84+
client_context.client.drop_database(db_name)
85+
if scenario_def['data']:
86+
data = scenario_def['data']
87+
# Load data.
88+
db['fs.chunks'].insert_many(data['fs.chunks'])
89+
db['fs.files'].insert_many(data['fs.files'])
90+
else:
91+
super(TestSpec, self).setup_scenario(scenario_def)
92+
7093

7194
def create_test(scenario_def, test, name):
7295
@client_context.require_test_commands

test/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,12 @@ def create_tests(self):
321321

322322
for filename in filenames:
323323
with open(os.path.join(dirpath, filename)) as scenario_stream:
324+
# Use tz_aware=False to match how CodecOptions decodes
325+
# dates.
326+
opts = json_util.JSONOptions(tz_aware=False)
324327
scenario_def = ScenarioDict(
325-
json_util.loads(scenario_stream.read()))
328+
json_util.loads(scenario_stream.read(),
329+
json_options=opts))
326330

327331
test_type = os.path.splitext(filename)[0]
328332

0 commit comments

Comments
 (0)