Skip to content

Commit ae71872

Browse files
authored
PYTHON-3297 Test auto decryption occurs after CommandSucceeded events (#980)
1 parent 1f7f46f commit ae71872

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

test/test_encryption.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from pymongo.encryption import Algorithm, ClientEncryption, QueryType
6262
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
6363
from pymongo.errors import (
64+
AutoReconnect,
6465
BulkWriteError,
6566
ConfigurationError,
6667
EncryptionError,
@@ -1769,6 +1770,83 @@ def test_case_8(self):
17691770
self.assertEqual(len(self.topology_listener.results["opened"]), 1)
17701771

17711772

1773+
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#14-decryption-events
1774+
class TestDecryptProse(EncryptionIntegrationTest):
1775+
def setUp(self):
1776+
self.client = client_context.client
1777+
self.client.db.drop_collection("decryption_events")
1778+
self.client.keyvault.drop_collection("datakeys")
1779+
self.client.keyvault.datakeys.create_index(
1780+
"keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}
1781+
)
1782+
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
1783+
1784+
self.client_encryption = ClientEncryption(
1785+
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
1786+
)
1787+
keyID = self.client_encryption.create_data_key("local")
1788+
self.cipher_text = self.client_encryption.encrypt(
1789+
"hello", key_id=keyID, algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
1790+
)
1791+
if self.cipher_text[-1] == 0:
1792+
self.malformed_cipher_text = self.cipher_text[:-1] + b"1"
1793+
else:
1794+
self.malformed_cipher_text = self.cipher_text[:-1] + b"0"
1795+
self.malformed_cipher_text = Binary(self.malformed_cipher_text, 6)
1796+
opts = AutoEncryptionOpts(
1797+
key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map
1798+
)
1799+
self.listener = AllowListEventListener("aggregate")
1800+
self.encrypted_client = MongoClient(
1801+
auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener]
1802+
)
1803+
self.addCleanup(self.encrypted_client.close)
1804+
1805+
def test_01_command_error(self):
1806+
with self.fail_point(
1807+
{
1808+
"mode": {"times": 1},
1809+
"data": {"errorCode": 123, "failCommands": ["aggregate"]},
1810+
}
1811+
):
1812+
with self.assertRaises(OperationFailure):
1813+
self.encrypted_client.db.decryption_events.aggregate([])
1814+
self.assertEqual(len(self.listener.results["failed"]), 1)
1815+
for event in self.listener.results["failed"]:
1816+
self.assertEqual(event.failure["code"], 123)
1817+
1818+
def test_02_network_error(self):
1819+
with self.fail_point(
1820+
{
1821+
"mode": {"times": 1},
1822+
"data": {"errorCode": 123, "closeConnection": True, "failCommands": ["aggregate"]},
1823+
}
1824+
):
1825+
with self.assertRaises(AutoReconnect):
1826+
self.encrypted_client.db.decryption_events.aggregate([])
1827+
self.assertEqual(len(self.listener.results["failed"]), 1)
1828+
self.assertEqual(self.listener.results["failed"][0].command_name, "aggregate")
1829+
1830+
def test_03_decrypt_error(self):
1831+
self.encrypted_client.db.decryption_events.insert_one(
1832+
{"encrypted": self.malformed_cipher_text}
1833+
)
1834+
with self.assertRaises(EncryptionError):
1835+
next(self.encrypted_client.db.decryption_events.aggregate([]))
1836+
event = self.listener.results["succeeded"][0]
1837+
self.assertEqual(len(self.listener.results["failed"]), 0)
1838+
self.assertEqual(
1839+
event.reply["cursor"]["firstBatch"][0]["encrypted"], self.malformed_cipher_text
1840+
)
1841+
1842+
def test_04_decrypt_success(self):
1843+
self.encrypted_client.db.decryption_events.insert_one({"encrypted": self.cipher_text})
1844+
next(self.encrypted_client.db.decryption_events.aggregate([]))
1845+
event = self.listener.results["succeeded"][0]
1846+
self.assertEqual(len(self.listener.results["failed"]), 0)
1847+
self.assertEqual(event.reply["cursor"]["firstBatch"][0]["encrypted"], self.cipher_text)
1848+
1849+
17721850
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#bypass-spawning-mongocryptd
17731851
class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
17741852
@unittest.skipIf(

0 commit comments

Comments
 (0)