Skip to content

Commit 63cfbbc

Browse files
committed
Don't unpack messages inside protocol
1 parent c2e62ce commit 63cfbbc

File tree

3 files changed

+128
-174
lines changed

3 files changed

+128
-174
lines changed

pymongo/asynchronous/network.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from pymongo.monitoring import _is_speculative_authenticate
4747
from pymongo.network_layer import (
4848
_UNPACK_COMPRESSION_HEADER,
49-
_UNPACK_HEADER, async_sendall_stream,
49+
_UNPACK_HEADER, async_sendall, async_receive_data,
5050
)
5151

5252
if TYPE_CHECKING:
@@ -201,15 +201,15 @@ async def command_stream(
201201

202202
try:
203203
write_start = time.monotonic()
204-
await async_sendall_stream(conn, msg)
204+
await async_sendall(conn, msg)
205205
write_elapsed = time.monotonic() - write_start
206206
if use_op_msg and unacknowledged:
207207
# Unacknowledged, fake a successful command response.
208208
reply = None
209209
response_doc: _DocumentOut = {"ok": 1}
210210
else:
211211
read_start = time.monotonic()
212-
reply = await receive_message_stream(conn, request_id)
212+
reply = await receive_message(conn, request_id)
213213
read_elapsed = time.monotonic() - read_start
214214
# if name == "insert":
215215
# TOTAL.append(write_elapsed + read_elapsed)
@@ -316,7 +316,7 @@ async def command_stream(
316316
return response_doc # type: ignore[return-value]
317317

318318

319-
async def receive_message_stream(
319+
async def receive_message(
320320
conn: AsyncConnectionStream, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
321321
) -> Union[_OpReply, _OpMsg]:
322322
"""Receive a raw BSON message or raise socket.error."""
@@ -330,35 +330,27 @@ async def receive_message_stream(
330330
# deadline = None
331331
deadline = None
332332
# Ignore the response's request id.
333-
# data = bytearray(max_message_size)
334-
conn.conn[1].reset()
335-
# try:
336-
data, op_code = await asyncio.wait_for(conn.conn[1].read(), timeout=5)
337-
# except asyncio.TimeoutError:
338-
# print(f"Timed out on read in {asyncio.current_task()}. Start of reading memory at {conn.conn[1].ready_offset}, start of writing memory at {conn.conn[1].empty_offset}, max of {MAX_MESSAGE_SIZE}, messages: {conn.conn[1]._messages}")
339-
340-
341-
# length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data_stream(conn, 16, deadline))
342-
# # No request_id for exhaust cursor "getMore".
343-
# if request_id is not None:
344-
# if request_id != response_to:
345-
# raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
346-
# if length <= 16:
347-
# raise ProtocolError(
348-
# f"Message length ({length!r}) not longer than standard message header size (16)"
349-
# )
350-
# if length > max_message_size:
351-
# raise ProtocolError(
352-
# f"Message length ({length!r}) is larger than server max "
353-
# f"message size ({max_message_size!r})"
354-
# )
355-
# if op_code == 2012:
356-
# op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
357-
# await async_receive_data_stream(conn, 9, deadline)
358-
# )
359-
# data = decompress(await async_receive_data_stream(conn, length - 25, deadline), compressor_id)
360-
# else:
361-
# data = await async_receive_data_stream(conn, length - 16, deadline)
333+
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
334+
# No request_id for exhaust cursor "getMore".
335+
if request_id is not None:
336+
if request_id != response_to:
337+
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
338+
if length <= 16:
339+
raise ProtocolError(
340+
f"Message length ({length!r}) not longer than standard message header size (16)"
341+
)
342+
if length > max_message_size:
343+
raise ProtocolError(
344+
f"Message length ({length!r}) is larger than server max "
345+
f"message size ({max_message_size!r})"
346+
)
347+
if op_code == 2012:
348+
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
349+
await async_receive_data(conn, 9, deadline)
350+
)
351+
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
352+
else:
353+
data = await async_receive_data(conn, length - 16, deadline)
362354

363355
try:
364356
unpack_reply = _UNPACK_REPLY[op_code]
@@ -367,4 +359,3 @@ async def receive_message_stream(
367359
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
368360
) from None
369361
return unpack_reply(data)
370-

pymongo/asynchronous/pool.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from pymongo import _csot, helpers_shared
4343
from pymongo.asynchronous.client_session import _validate_session_write_concern
4444
from pymongo.asynchronous.helpers import _handle_reauth
45-
from pymongo.asynchronous.network import command_stream, receive_message_stream
45+
from pymongo.asynchronous.network import command_stream, receive_message
4646
from pymongo.common import (
4747
MAX_BSON_SIZE,
4848
MAX_MESSAGE_SIZE,
@@ -80,7 +80,7 @@
8080
ConnectionCheckOutFailedReason,
8181
ConnectionClosedReason,
8282
)
83-
from pymongo.network_layer import async_sendall_stream, _UNPACK_HEADER, PyMongoProtocol
83+
from pymongo.network_layer import async_sendall, _UNPACK_HEADER, PyMongoProtocol
8484
from pymongo.pool_options import PoolOptions
8585
from pymongo.read_preferences import ReadPreference
8686
from pymongo.server_api import _add_to_command
@@ -576,7 +576,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
576576
)
577577

578578
try:
579-
await async_sendall_stream(self.conn, message)
579+
await async_sendall(self.conn, message)
580580
except BaseException as error:
581581
self._raise_connection_failure(error)
582582

@@ -586,7 +586,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
586586
If any exception is raised, the socket is closed.
587587
"""
588588
try:
589-
return await receive_message_stream(self, request_id, self.max_message_size)
589+
return await receive_message(self, request_id, self.max_message_size)
590590
except BaseException as error:
591591
self._raise_connection_failure(error)
592592

@@ -1107,7 +1107,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
11071107
)
11081108

11091109
try:
1110-
await async_sendall_stream(self, message)
1110+
await async_sendall(self, message)
11111111
except BaseException as error:
11121112
self._raise_connection_failure(error)
11131113

@@ -1117,7 +1117,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
11171117
If any exception is raised, the socket is closed.
11181118
"""
11191119
try:
1120-
return await receive_message_stream(self, request_id, self.max_message_size)
1120+
return await receive_message(self, request_id, self.max_message_size)
11211121
except BaseException as error:
11221122
self._raise_connection_failure(error)
11231123

0 commit comments

Comments
 (0)