Skip to content

Commit 5886631

Browse files
committed
PYTHON-1884 Support auto encryption in cursors
1 parent 8888e97 commit 5886631

File tree

5 files changed

+56
-14
lines changed

5 files changed

+56
-14
lines changed

pymongo/encryption.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,19 +197,27 @@ def __init__(self, io_callbacks, opts):
197197
opts._kms_providers, schema_map))
198198
self._bypass_auto_encryption = opts._bypass_auto_encryption
199199

200-
def encrypt(self, database, cmd):
200+
def encrypt(self, database, cmd, check_keys, codec_options):
201201
"""Encrypt a MongoDB command.
202202
203203
:Parameters:
204204
- `database`: The database for this command.
205-
- `cmd`: A command as BSON.
205+
- `cmd`: A command document.
206+
- `check_keys`: If True, check `cmd` for invalid keys.
207+
- `codec_options`: The CodecOptions to use while encoding `cmd`.
206208
207209
:Returns:
208210
The encrypted command to execute.
209211
"""
210-
encrypted_cmd = self._auto_encrypter.encrypt(database, cmd)
212+
# Workaround for $clusterTime which is incompatible with check_keys.
213+
cluster_time = check_keys and cmd.pop('$clusterTime', None)
214+
encrypted_cmd = self._auto_encrypter.encrypt(
215+
database, _dict_to_bson(cmd, check_keys, codec_options))
211216
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
212-
return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
217+
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
218+
if cluster_time:
219+
encrypt_cmd['$clusterTime'] = cluster_time
220+
return encrypt_cmd
213221

214222
def decrypt(self, response):
215223
"""Decrypt a MongoDB command response.

pymongo/message.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,12 @@ def as_command(self, sock_info):
305305
'readConcern', {})[
306306
'afterClusterTime'] = session.operation_time
307307
sock_info.send_cluster_time(cmd, session, self.client)
308+
# Support auto encryption
309+
client = self.client
310+
if (client._encrypter and
311+
not client._encrypter._bypass_auto_encryption):
312+
cmd = client._encrypter.encrypt(
313+
self.db, cmd, False, self.codec_options)
308314
self._as_command = cmd, self.db
309315
return self._as_command
310316

@@ -393,6 +399,12 @@ def as_command(self, sock_info):
393399
if self.session:
394400
self.session._apply_to(cmd, False, self.read_preference)
395401
sock_info.send_cluster_time(cmd, self.session, self.client)
402+
# Support auto encryption
403+
client = self.client
404+
if (client._encrypter and
405+
not client._encrypter._bypass_auto_encryption):
406+
cmd = client._encrypter.encrypt(
407+
self.db, cmd, False, self.codec_options)
396408
self._as_command = cmd, self.db
397409
return self._as_command
398410

pymongo/network.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
except ImportError:
3535
_SELECT_ERROR = OSError
3636

37-
from bson import _dict_to_bson, _decode_all_selective
37+
from bson import _decode_all_selective
3838
from bson.py3compat import PY3
3939

4040
from pymongo import helpers, message
@@ -117,13 +117,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
117117

118118
if (client and client._encrypter and
119119
not client._encrypter._bypass_auto_encryption):
120-
# Workaround for $clusterTime which is incompatible with check_keys.
121-
if check_keys:
122-
cluster_time = spec.pop('$clusterTime', None)
123120
spec = orig = client._encrypter.encrypt(
124-
dbname, _dict_to_bson(spec, check_keys, codec_options))
125-
if check_keys and cluster_time:
126-
spec['$clusterTime'] = cluster_time
121+
dbname, spec, check_keys, codec_options)
127122
# We already checked the keys, no need to do it again.
128123
check_keys = False
129124

pymongo/server.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from datetime import datetime
1818

19+
from bson import _decode_all_selective
20+
1921
from pymongo.errors import NotMasterError, OperationFailure
2022
from pymongo.helpers import _check_command_response
2123
from pymongo.message import _convert_exception
@@ -164,6 +166,15 @@ def run_operation_with_response(
164166
duration, res, operation.name, request_id,
165167
sock_info.address)
166168

169+
# Decrypt response.
170+
client = operation.client
171+
if client and client._encrypter:
172+
if use_cmd:
173+
decrypted = client._encrypter.decrypt(
174+
reply.raw_command_response())
175+
docs = _decode_all_selective(
176+
decrypted, operation.codec_options, user_fields)
177+
167178
if exhaust:
168179
response = ExhaustResponse(
169180
data=reply,

test/test_encryption.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,22 +157,38 @@ def _test_auto_encrypt(self, opts):
157157
self.addCleanup(key_vault.drop)
158158

159159
# Collection.insert_one auto encrypts.
160+
docs = [{'_id': 1, 'ssn': '123'},
161+
{'_id': 2, 'ssn': '456'},
162+
{'_id': 3, 'ssn': '789'}]
160163
encrypted_coll = client.pymongo_test.test
161-
encrypted_coll.insert_one({'_id': 1, 'ssn': '123'})
164+
for doc in docs:
165+
encrypted_coll.insert_one(doc)
162166

163167
# Database.command auto decrypts.
164168
res = client.pymongo_test.command(
165169
'find', 'test', filter={'ssn': '123'})
166170
decrypted_docs = res['cursor']['firstBatch']
167171
self.assertEqual(decrypted_docs, [{'_id': 1, 'ssn': '123'}])
168172

173+
# Collection.find auto decrypts.
174+
decrypted_docs = list(encrypted_coll.find())
175+
self.assertEqual(decrypted_docs, docs)
176+
177+
# Collection.find auto decrypts getMores.
178+
decrypted_docs = list(encrypted_coll.find(batch_size=1))
179+
self.assertEqual(decrypted_docs, docs)
180+
169181
# Collection.aggregate auto decrypts.
170182
decrypted_docs = list(encrypted_coll.aggregate([]))
171-
self.assertEqual(decrypted_docs, [{'_id': 1, 'ssn': '123'}])
183+
self.assertEqual(decrypted_docs, docs)
184+
185+
# Collection.aggregate auto decrypts getMores.
186+
decrypted_docs = list(encrypted_coll.aggregate([], batchSize=1))
187+
self.assertEqual(decrypted_docs, docs)
172188

173189
# Collection.distinct auto decrypts.
174190
decrypted_ssns = encrypted_coll.distinct('ssn')
175-
self.assertEqual(decrypted_ssns, ['123'])
191+
self.assertEqual(decrypted_ssns, ['123', '456', '789'])
176192

177193
# Make sure the field is actually encrypted.
178194
encrypted_doc = self.db.test.find_one()

0 commit comments

Comments
 (0)