Skip to content

Commit 7d76ee2

Browse files
committed
PYTHON-1884 Support auto encryption in bulk write
Close KMS sockets. Call pymongocrypt's init method.
1 parent 5886631 commit 7d76ee2

File tree

4 files changed

+138
-51
lines changed

pymongo/bulk.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
OperationFailure)
3737
from pymongo.message import (_INSERT, _UPDATE, _DELETE,
3838
_do_batched_insert,
39-
_do_bulk_write_command,
4039
_randint,
41-
_BulkWriteContext)
40+
_BulkWriteContext,
41+
_EncryptedBulkWriteContext)
4242
from pymongo.read_preferences import ReadPreference
4343
from pymongo.write_concern import WriteConcern
4444

@@ -152,8 +152,6 @@ def __init__(self, collection, ordered, bypass_document_validation):
152152
document_class=dict))
153153
self.ordered = ordered
154154
self.ops = []
155-
self.name = "%s.%s" % (collection.database.name, collection.name)
156-
self.namespace = collection.database.name + '.$cmd'
157155
self.executed = False
158156
self.bypass_doc_val = bypass_document_validation
159157
self.uses_collation = False
@@ -164,6 +162,14 @@ def __init__(self, collection, ordered, bypass_document_validation):
164162
# Extra state so that we know where to pick up on a retry attempt.
165163
self.current_run = None
166164

165+
@property
166+
def bulk_ctx_class(self):
167+
encrypter = self.collection.database.client._encrypter
168+
if encrypter and not encrypter._bypass_auto_encryption:
169+
return _EncryptedBulkWriteContext
170+
else:
171+
return _BulkWriteContext
172+
167173
def add_insert(self, document):
168174
"""Add an insert document to the list of ops.
169175
"""
@@ -271,8 +277,9 @@ def _execute_command(self, generator, write_concern, session,
271277
cmd['writeConcern'] = write_concern.document
272278
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
273279
cmd['bypassDocumentValidation'] = True
274-
bwc = _BulkWriteContext(db_name, cmd, sock_info, op_id,
275-
listeners, session)
280+
bwc = self.bulk_ctx_class(
281+
db_name, cmd, sock_info, op_id, listeners, session,
282+
run.op_type, self.collection.codec_options)
276283

277284
while run.idx_offset < len(run.ops):
278285
if session:
@@ -283,16 +290,9 @@ def _execute_command(self, generator, write_concern, session,
283290
self.started_retryable_write = True
284291
session._apply_to(cmd, retryable, ReadPreference.PRIMARY)
285292
sock_info.send_cluster_time(cmd, session, client)
286-
check_keys = run.op_type == _INSERT
287293
ops = islice(run.ops, run.idx_offset, None)
288-
# Run as many ops as possible.
289-
request_id, msg, to_send = _do_bulk_write_command(
290-
self.namespace, run.op_type, cmd, ops, check_keys,
291-
self.collection.codec_options, bwc)
292-
if not to_send:
293-
raise InvalidOperation("cannot do an empty bulk write")
294-
result = bwc.write_command(request_id, msg, to_send)
295-
client._process_response(result, session)
294+
# Run as many ops as possible in one command.
295+
result, to_send = bwc.execute(ops, client)
296296

297297
# Retryable writeConcernErrors halt the execution of this run.
298298
wce = result.get('writeConcernError', {})
@@ -361,7 +361,7 @@ def execute_insert_no_results(self, sock_info, run, op_id, acknowledged):
361361
db = self.collection.database
362362
bwc = _BulkWriteContext(
363363
db.name, command, sock_info, op_id, db.client._event_listeners,
364-
session=None)
364+
None, _INSERT, self.collection.codec_options)
365365
# Legacy batched OP_INSERT.
366366
_do_batched_insert(
367367
self.collection.full_name, run.ops, True, acknowledged, concern,
@@ -383,25 +383,15 @@ def execute_op_msg_no_results(self, sock_info, generator):
383383
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
384384
('ordered', False),
385385
('writeConcern', {'w': 0})])
386-
bwc = _BulkWriteContext(db_name, cmd, sock_info, op_id,
387-
listeners, None)
386+
bwc = self.bulk_ctx_class(
387+
db_name, cmd, sock_info, op_id, listeners, None,
388+
run.op_type, self.collection.codec_options)
388389

389390
while run.idx_offset < len(run.ops):
390-
check_keys = run.op_type == _INSERT
391391
ops = islice(run.ops, run.idx_offset, None)
392392
# Run as many ops as possible.
393-
request_id, msg, to_send = _do_bulk_write_command(
394-
self.namespace, run.op_type, cmd, ops, check_keys,
395-
self.collection.codec_options, bwc)
396-
if not to_send:
397-
raise InvalidOperation("cannot do an empty bulk write")
393+
to_send = bwc.execute_unack(ops, client)
398394
run.idx_offset += len(to_send)
399-
# Though this isn't strictly a "legacy" write, the helper
400-
# handles publishing commands and sending our message
401-
# without receiving a result. Send 0 for max_doc_size
402-
# to disable size checking. Size checking is handled while
403-
# the documents are encoded to BSON.
404-
bwc.legacy_write(request_id, msg, 0, False, to_send)
405395
self.current_run = run = next(generator, None)
406396

407397
def execute_command_no_results(self, sock_info, generator):

pymongo/encryption.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,14 @@ def kms_request(self, kms_context):
7474
socket_timeout=_KMS_CONNECT_TIMEOUT,
7575
ssl_context=ctx)
7676
try:
77-
conn = _configured_socket((endpoint, _HTTPS_PORT), opts)
78-
conn.sendall(message)
77+
with _configured_socket((endpoint, _HTTPS_PORT), opts) as conn:
78+
conn.sendall(message)
79+
while kms_context.bytes_needed > 0:
80+
data = conn.recv(kms_context.bytes_needed)
81+
kms_context.feed(data)
7982
except Exception as exc:
8083
raise MongoCryptError(str(exc))
8184

82-
while kms_context.bytes_needed > 0:
83-
data = conn.recv(kms_context.bytes_needed)
84-
kms_context.feed(data)
85-
8685
def collection_info(self, database, filter):
8786
"""Get the collection info for a namespace.
8887

pymongo/message.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_dict_to_bson,
3030
_make_c_string)
3131
from bson.codec_options import DEFAULT_CODEC_OPTIONS
32+
from bson.raw_bson import _inflate_bson, DEFAULT_RAW_BSON_OPTIONS
3233
from bson.py3compat import b, StringIO
3334
from bson.son import SON
3435

@@ -47,6 +48,7 @@
4748
ProtocolError)
4849
from pymongo.read_concern import DEFAULT_READ_CONCERN
4950
from pymongo.read_preferences import ReadPreference
51+
from pymongo.write_concern import WriteConcern
5052

5153

5254
MAX_INT32 = 2147483647
@@ -862,10 +864,10 @@ class _BulkWriteContext(object):
862864

863865
__slots__ = ('db_name', 'command', 'sock_info', 'op_id',
864866
'name', 'field', 'publish', 'start_time', 'listeners',
865-
'session', 'compress')
867+
'session', 'compress', 'op_type', 'codec')
866868

867869
def __init__(self, database_name, command, sock_info, operation_id,
868-
listeners, session):
870+
listeners, session, op_type, codec):
869871
self.db_name = database_name
870872
self.command = command
871873
self.sock_info = sock_info
@@ -877,6 +879,38 @@ def __init__(self, database_name, command, sock_info, operation_id,
877879
self.start_time = datetime.datetime.now() if self.publish else None
878880
self.session = session
879881
self.compress = True if sock_info.compression_context else False
882+
self.op_type = op_type
883+
self.codec = codec
884+
885+
def _batch_command(self, docs):
886+
namespace = self.db_name + '.$cmd'
887+
request_id, msg, to_send = _do_bulk_write_command(
888+
namespace, self.op_type, self.command, docs, self.check_keys,
889+
self.codec, self)
890+
if not to_send:
891+
raise InvalidOperation("cannot do an empty bulk write")
892+
return request_id, msg, to_send
893+
894+
def execute(self, docs, client):
895+
request_id, msg, to_send = self._batch_command(docs)
896+
result = self.write_command(request_id, msg, to_send)
897+
client._process_response(result, self.session)
898+
return result, to_send
899+
900+
def execute_unack(self, docs, client):
901+
request_id, msg, to_send = self._batch_command(docs)
902+
# Though this isn't strictly a "legacy" write, the helper
903+
# handles publishing commands and sending our message
904+
# without receiving a result. Send 0 for max_doc_size
905+
# to disable size checking. Size checking is handled while
906+
# the documents are encoded to BSON.
907+
self.legacy_write(request_id, msg, 0, False, to_send)
908+
return to_send
909+
910+
@property
911+
def check_keys(self):
912+
"""Should we check keys for this operation type?"""
913+
return self.op_type == _INSERT
880914

881915
@property
882916
def max_bson_size(self):
@@ -975,6 +1009,54 @@ def _fail(self, request_id, failure, duration):
9751009
request_id, self.sock_info.address, self.op_id)
9761010

9771011

1012+
# 2MiB
1013+
_MAX_ENC_BSON_SIZE = 2 * (1024 * 1024)
1014+
# 6MB
1015+
_MAX_ENC_MESSAGE_SIZE = 6 * (1000 * 1000)
1016+
1017+
1018+
class _EncryptedBulkWriteContext(_BulkWriteContext):
1019+
__slots__ = ()
1020+
1021+
def _batch_command(self, docs):
1022+
namespace = self.db_name + '.$cmd'
1023+
msg, to_send = _encode_batched_write_command(
1024+
namespace, self.op_type, self.command, docs, self.check_keys,
1025+
self.codec, self)
1026+
if not to_send:
1027+
raise InvalidOperation("cannot do an empty bulk write")
1028+
1029+
# Chop off the OP_QUERY header to get a properly batched write command.
1030+
cmd_start = msg.index(b"\x00", 4) + 9
1031+
cmd = _inflate_bson(memoryview(msg)[cmd_start:],
1032+
DEFAULT_RAW_BSON_OPTIONS)
1033+
return cmd, to_send
1034+
1035+
def execute(self, docs, client):
1036+
cmd, to_send = self._batch_command(docs)
1037+
result = self.sock_info.command(
1038+
self.db_name, cmd, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
1039+
session=self.session, client=client)
1040+
return result, to_send
1041+
1042+
def execute_unack(self, docs, client):
1043+
cmd, to_send = self._batch_command(docs)
1044+
self.sock_info.command(
1045+
self.db_name, cmd, write_concern=WriteConcern(w=0),
1046+
session=self.session, client=client)
1047+
return to_send
1048+
1049+
@property
1050+
def max_bson_size(self):
1051+
"""A proxy for SockInfo.max_bson_size."""
1052+
return min(self.sock_info.max_bson_size, _MAX_ENC_BSON_SIZE)
1053+
1054+
@property
1055+
def max_message_size(self):
1056+
"""A proxy for SockInfo.max_message_size."""
1057+
return min(self.sock_info.max_message_size, _MAX_ENC_MESSAGE_SIZE)
1058+
1059+
9781060
def _raise_document_too_large(operation, doc_size, max_size):
9791061
"""Internal helper for raising DocumentTooLarge."""
9801062
if operation == "insert":

test/test_encryption.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,16 @@
3030
from pymongo.errors import ConfigurationError
3131
from pymongo.mongo_client import MongoClient
3232
from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT
33+
from pymongo.write_concern import WriteConcern
3334

3435
from test import unittest, IntegrationTest, PyMongoTestCase, client_context
36+
from test.utils import wait_until
37+
38+
39+
if _HAVE_PYMONGOCRYPT:
40+
# Load the mongocrypt library.
41+
from pymongocrypt.binding import init
42+
init(os.environ.get('MONGOCRYPT_LIB', 'mongocrypt'))
3543

3644

3745
def get_client_opts(client):
@@ -156,19 +164,27 @@ def _test_auto_encrypt(self, opts):
156164
key_vault.insert_one(data_key)
157165
self.addCleanup(key_vault.drop)
158166

159-
# Collection.insert_one auto encrypts.
160-
docs = [{'_id': 1, 'ssn': '123'},
161-
{'_id': 2, 'ssn': '456'},
162-
{'_id': 3, 'ssn': '789'}]
167+
# Collection.insert_one/insert_many auto encrypts.
168+
docs = [{'_id': 0, 'ssn': '000'},
169+
{'_id': 1, 'ssn': '111'},
170+
{'_id': 2, 'ssn': '222'},
171+
{'_id': 3, 'ssn': '333'},
172+
{'_id': 4, 'ssn': '444'},
173+
{'_id': 5, 'ssn': '555'}]
163174
encrypted_coll = client.pymongo_test.test
164-
for doc in docs:
165-
encrypted_coll.insert_one(doc)
175+
encrypted_coll.insert_one(docs[0])
176+
encrypted_coll.insert_many(docs[1:3])
177+
unack = encrypted_coll.with_options(write_concern=WriteConcern(w=0))
178+
unack.insert_one(docs[3])
179+
unack.insert_many(docs[4:], ordered=False)
180+
wait_until(lambda: self.db.test.count_documents({}) == len(docs),
181+
'insert documents with w=0')
166182

167183
# Database.command auto decrypts.
168184
res = client.pymongo_test.command(
169-
'find', 'test', filter={'ssn': '123'})
185+
'find', 'test', filter={'ssn': '000'})
170186
decrypted_docs = res['cursor']['firstBatch']
171-
self.assertEqual(decrypted_docs, [{'_id': 1, 'ssn': '123'}])
187+
self.assertEqual(decrypted_docs, [{'_id': 0, 'ssn': '000'}])
172188

173189
# Collection.find auto decrypts.
174190
decrypted_docs = list(encrypted_coll.find())
@@ -188,13 +204,13 @@ def _test_auto_encrypt(self, opts):
188204

189205
# Collection.distinct auto decrypts.
190206
decrypted_ssns = encrypted_coll.distinct('ssn')
191-
self.assertEqual(decrypted_ssns, ['123', '456', '789'])
207+
self.assertEqual(decrypted_ssns, [d['ssn'] for d in docs])
192208

193209
# Make sure the field is actually encrypted.
194-
encrypted_doc = self.db.test.find_one()
195-
self.assertEqual(encrypted_doc['_id'], 1)
196-
self.assertIsInstance(encrypted_doc['ssn'], Binary)
197-
self.assertEqual(encrypted_doc['ssn'].subtype, 6)
210+
for encrypted_doc in self.db.test.find():
211+
self.assertIsInstance(encrypted_doc['_id'], int)
212+
self.assertIsInstance(encrypted_doc['ssn'], Binary)
213+
self.assertEqual(encrypted_doc['ssn'].subtype, 6)
198214

199215
def test_auto_encrypt(self):
200216
# Configure the encrypted field via jsonSchema.

0 commit comments

Comments (0)