
Commit 2aff6b1

Reset Protocol after each message
1 parent b94afb4 commit 2aff6b1

File tree

4 files changed (+23 additions, −31 deletions)


pymongo/network_layer.py

Lines changed: 21 additions & 19 deletions
@@ -483,7 +483,6 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
         self._timeout = timeout
         self._is_compressed = False
         self._compressor_id: Optional[int] = None
-        self._need_compression_header = False
         self._max_message_size = MAX_MESSAGE_SIZE
         self._request_id: Optional[int] = None
         self._closed = asyncio.get_running_loop().create_future()
@@ -539,36 +538,38 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
         if read_waiter in self._done_messages:
             self._done_messages.remove(read_waiter)
         if message:
-            start, end, op_code, overflow, overflow_index = (
+            start, end, op_code, is_compressed, compressor_id, overflow, overflow_index = (
                 message[0],
                 message[1],
                 message[2],
                 message[3],
                 message[4],
+                message[5],
+                message[6],
             )
-            if self._is_compressed:
+            if is_compressed:
                 header_size = 25
             else:
                 header_size = 16
             if overflow is not None:
-                if self._is_compressed and self._compressor_id is not None:
+                if is_compressed and compressor_id is not None:
                     return decompress(
                         memoryview(
                             bytearray(self._buffer[start + header_size : self._end_index])
                             + bytearray(overflow[:overflow_index])
                         ),
-                        self._compressor_id,
+                        compressor_id,
                     ), op_code
                 else:
                     return memoryview(
                         bytearray(self._buffer[start + header_size : self._end_index])
                         + bytearray(overflow[:overflow_index])
                     ), op_code
             else:
-                if self._is_compressed and self._compressor_id is not None:
+                if is_compressed and compressor_id is not None:
                     return decompress(
                         memoryview(self._buffer[start + header_size : end]),
-                        self._compressor_id,
+                        compressor_id,
                     ), op_code
                 else:
                     return memoryview(self._buffer[start + header_size : end]), op_code
@@ -624,6 +625,8 @@ def buffer_updated(self, nbytes: int) -> None:
                     self._start_index,
                     self._body_size + self._start_index,
                     self._op_code,
+                    self._is_compressed,
+                    self._compressor_id,
                     self._overflow,
                     self._overflow_index,
                 )
@@ -635,18 +638,20 @@ def buffer_updated(self, nbytes: int) -> None:
             else:
                 self._start_index += self._body_size
             self._done_messages.append(result)
+            # Reset internal state to expect a new message
+            self._expecting_header = True
+            self._body_size = 0
+            self._op_code = None  # type: ignore[assignment]
+            self._overflow = None
+            self._overflow_index = 0
+            self._is_compressed = False
+            self._compressor_id = None
             # If at least one header's worth of data remains after the current message, reprocess all leftover data
             if self._end_index - self._start_index >= 16:
                 self._read_waiter = asyncio.get_running_loop().create_future()
                 self._pending_messages.append(self._read_waiter)
                 nbytes_reprocess = self._end_index - self._start_index
                 self._end_index -= nbytes_reprocess
-                # Reset internal state to expect a new message
-                self._expecting_header = True
-                self._body_size = 0
-                self._op_code = None  # type: ignore[assignment]
-                self._overflow = None
-                self._overflow_index = 0
                 self.buffer_updated(nbytes_reprocess)
                 # Pause reading to avoid storing an arbitrary number of messages in memory before necessary
                 self.transport.pause_reading()
@@ -673,12 +678,9 @@ def process_header(self) -> tuple[int, int]:
         )
         if op_code == 2012:
             self._is_compressed = True
-            if self._end_index >= 25:
-                op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(
-                    self._buffer[self._start_index + 16 : self._start_index + 25]
-                )
-            else:
-                self._need_compression_header = True
+            op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(
+                self._buffer[self._start_index + 16 : self._start_index + 25]
+            )
 
         return length, op_code
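
The substance of the network_layer.py change: compression metadata (is_compressed, compressor_id) now travels inside each completed message tuple, and the protocol's parsing state, including those two fields, is reset immediately after every message is handed off to _done_messages rather than only on the path that reprocesses leftover bytes. A compressed message therefore cannot leak its flags into the parsing of the next one. Below is a minimal sketch of the same pattern on a toy length-prefixed framing; FramedParser and its fields are illustrative stand-ins, not PyMongo's actual class or API.

import struct
from typing import Optional

class FramedParser:
    """Toy framed-message parser: 4-byte little-endian length + 1-byte 'compressor id' flag."""

    HEADER = struct.Struct("<iB")

    def __init__(self) -> None:
        self._buffer = bytearray()
        self._reset()

    def _reset(self) -> None:
        # Reset internal state to expect a new message.
        self._expecting_header = True
        self._body_size = 0
        self._is_compressed = False
        self._compressor_id: Optional[int] = None

    def feed(self, data: bytes) -> list[tuple[bytes, bool, Optional[int]]]:
        """Return (payload, is_compressed, compressor_id) for each complete message."""
        self._buffer.extend(data)
        messages = []
        while True:
            if self._expecting_header:
                if len(self._buffer) < self.HEADER.size:
                    break
                self._body_size, flag = self.HEADER.unpack_from(self._buffer)
                self._is_compressed = flag != 0
                self._compressor_id = flag if flag else None
                self._expecting_header = False
            if len(self._buffer) < self.HEADER.size + self._body_size:
                break
            start = self.HEADER.size
            end = start + self._body_size
            # Snapshot per-message metadata into the result *before* resetting,
            # mirroring the commit: the reader consumes the snapshot, not the
            # parser's live attributes.
            messages.append(
                (bytes(self._buffer[start:end]), self._is_compressed, self._compressor_id)
            )
            del self._buffer[:end]
            self._reset()  # reset after *every* message, not only on the reprocess path
        return messages


if __name__ == "__main__":
    p = FramedParser()
    compressed = struct.pack("<iB", 3, 2) + b"abc"  # pretend compressor id 2
    plain = struct.pack("<iB", 2, 0) + b"ok"
    # Both frames arrive in one read; the second must be parsed with fresh state.
    print(p.feed(compressed + plain))

Running the sketch prints [(b'abc', True, 2), (b'ok', False, None)]: the second frame is parsed with fresh state even though both arrived in a single feed, which is exactly the guarantee the relocated reset block provides.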

test/asynchronous/test_client_bulk_write.py

Lines changed: 0 additions & 6 deletions
@@ -676,14 +676,8 @@ async def test_timeout_in_multi_batch_bulk_write(self):
         listener = OvertCommandListener()
         client = await self.async_rs_or_single_client(
             event_listeners=[listener],
-            readConcernLevel="majority",
-            readPreference="primary",
             timeoutMS=2000,
-            w="majority",
         )
-        # Initialize the client with a larger timeout to help make test less flakey
-        with pymongo.timeout(10):
-            await client.admin.command("ping")
         with self.assertRaises(ClientBulkWriteException) as context:
             await client.bulk_write(models=models)
         self.assertIsInstance(context.exception.error, NetworkTimeout)

test/test_client_bulk_write.py

Lines changed: 0 additions & 6 deletions
@@ -672,14 +672,8 @@ def test_timeout_in_multi_batch_bulk_write(self):
         listener = OvertCommandListener()
         client = self.rs_or_single_client(
             event_listeners=[listener],
-            readConcernLevel="majority",
-            readPreference="primary",
             timeoutMS=2000,
-            w="majority",
         )
-        # Initialize the client with a larger timeout to help make test less flakey
-        with pymongo.timeout(10):
-            client.admin.command("ping")
         with self.assertRaises(ClientBulkWriteException) as context:
             client.bulk_write(models=models)
         self.assertIsInstance(context.exception.error, NetworkTimeout)
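
For context on the test simplification in both files: the deleted warm-up wrapped a ping in pymongo.timeout(), PyMongo's context manager for temporarily overriding timeoutMS on the operations inside the block. A minimal usage sketch of that API, with an illustrative connection string:

import pymongo

# timeoutMS=2000 is the client-wide default operation budget ...
client = pymongo.MongoClient("mongodb://localhost:27017", timeoutMS=2000)

# ... but pymongo.timeout() raises it to 10 seconds for everything inside
# the block, which is what the removed warm-up ping relied on.
with pymongo.timeout(10):
    client.admin.command("ping")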

test/utils.py

Lines changed: 2 additions & 0 deletions
@@ -318,6 +318,7 @@ def __init__(self):
         self.cancel_context = _CancellationContext()
         self.more_to_come = False
         self.id = random.randint(0, 100)
+        self.server_connection_id = random.randint(0, 100)
 
     def close_conn(self, reason):
         pass
@@ -334,6 +335,7 @@ def __init__(self):
         self.cancel_context = _CancellationContext()
         self.more_to_come = False
         self.id = random.randint(0, 100)
+        self.server_connection_id = random.randint(0, 100)
 
     def close_conn(self, reason):
         pass
