Skip to content

Commit 2d0f4c1

Browse files
committed
Final POC for Wire Protocol-handling protocol
1 parent 55dfaca commit 2d0f4c1

File tree

8 files changed

+853
-281
lines changed

8 files changed

+853
-281
lines changed

pymongo/asynchronous/network.py

Lines changed: 7 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@
1515
"""Internal network layer helper methods."""
1616
from __future__ import annotations
1717

18-
import asyncio
1918
import datetime
2019
import logging
21-
import statistics
22-
import time
23-
from asyncio import streams, StreamReader
2420
from typing import (
2521
TYPE_CHECKING,
2622
Any,
@@ -34,26 +30,24 @@
3430

3531
from bson import _decode_all_selective
3632
from pymongo import _csot, helpers_shared, message
37-
from pymongo.common import MAX_MESSAGE_SIZE
38-
from pymongo.compression_support import _NO_COMPRESSION, decompress
33+
from pymongo.compression_support import _NO_COMPRESSION
3934
from pymongo.errors import (
4035
NotPrimaryError,
4136
OperationFailure,
42-
ProtocolError,
4337
)
4438
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
45-
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
39+
from pymongo.message import _OpMsg
4640
from pymongo.monitoring import _is_speculative_authenticate
4741
from pymongo.network_layer import (
48-
_UNPACK_COMPRESSION_HEADER,
49-
_UNPACK_HEADER, async_sendall, async_receive_data,
42+
async_receive_message,
43+
async_sendall,
5044
)
5145

5246
if TYPE_CHECKING:
5347
from bson import CodecOptions
5448
from pymongo.asynchronous.client_session import AsyncClientSession
5549
from pymongo.asynchronous.mongo_client import AsyncMongoClient
56-
from pymongo.asynchronous.pool import AsyncConnection, AsyncStreamConnection, AsyncConnectionProtocol
50+
from pymongo.asynchronous.pool import AsyncConnectionProtocol
5751
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
5852
from pymongo.monitoring import _EventListeners
5953
from pymongo.read_concern import ReadConcern
@@ -63,13 +57,8 @@
6357

6458
_IS_SYNC = False
6559

66-
# TOTAL = []
67-
# TOTAL_WRITE = []
68-
# TOTAL_READ = []
69-
# print(f"TOTALS: {TOTAL, TOTAL_WRITE, TOTAL_READ}")
7060

71-
72-
async def command_stream(
61+
async def command(
7362
conn: AsyncConnectionProtocol,
7463
dbname: str,
7564
spec: MutableMapping[str, Any],
@@ -200,24 +189,13 @@ async def command_stream(
200189
)
201190

202191
try:
203-
write_start = time.monotonic()
204192
await async_sendall(conn, msg)
205-
write_elapsed = time.monotonic() - write_start
206193
if use_op_msg and unacknowledged:
207194
# Unacknowledged, fake a successful command response.
208195
reply = None
209196
response_doc: _DocumentOut = {"ok": 1}
210197
else:
211-
read_start = time.monotonic()
212-
reply = await receive_message(conn, request_id)
213-
read_elapsed = time.monotonic() - read_start
214-
# if name == "insert":
215-
# TOTAL.append(write_elapsed + read_elapsed)
216-
# TOTAL_READ.append(read_elapsed)
217-
# TOTAL_WRITE.append(write_elapsed)
218-
# if name == "endSessions":
219-
# print(
220-
# f"AVERAGE READ: {statistics.mean(TOTAL_READ)}, AVERAGE WRITE: {statistics.mean(TOTAL_WRITE)}, AVERAGE ELAPSED: {statistics.mean(TOTAL)}")
198+
reply = await async_receive_message(conn, request_id)
221199
conn.more_to_come = reply.more_to_come
222200
unpacked_docs = reply.unpack_response(
223201
codec_options=codec_options, user_fields=user_fields
@@ -314,48 +292,3 @@ async def command_stream(
314292
)
315293

316294
return response_doc # type: ignore[return-value]
317-
318-
319-
async def receive_message(
320-
conn: AsyncConnectionProtocol, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
321-
) -> Union[_OpReply, _OpMsg]:
322-
"""Receive a raw BSON message or raise socket.error."""
323-
# if _csot.get_timeout():
324-
# deadline = _csot.get_deadline()
325-
# else:
326-
# timeout = conn.conn.gettimeout()
327-
# if timeout:
328-
# deadline = time.monotonic() + timeout
329-
# else:
330-
# deadline = None
331-
deadline = None
332-
# Ignore the response's request id.
333-
data, op_code = await async_receive_data(conn, 0, deadline)
334-
# length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
335-
# No request_id for exhaust cursor "getMore".
336-
# if request_id is not None:
337-
# if request_id != response_to:
338-
# raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
339-
# if length <= 16:
340-
# raise ProtocolError(
341-
# f"Message length ({length!r}) not longer than standard message header size (16)"
342-
# )
343-
# if length > max_message_size:
344-
# raise ProtocolError(
345-
# f"Message length ({length!r}) is larger than server max "
346-
# f"message size ({max_message_size!r})"
347-
# )
348-
# if op_code == 2012:
349-
# op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
350-
# await async_receive_data(conn, 9, deadline)
351-
# )
352-
# data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
353-
# else:
354-
# data = await async_receive_data(conn, length - 16, deadline)
355-
try:
356-
unpack_reply = _UNPACK_REPLY[op_code]
357-
except KeyError:
358-
raise ProtocolError(
359-
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
360-
) from None
361-
return unpack_reply(data)

pymongo/asynchronous/pool.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@
3737
Union,
3838
)
3939

40-
from asyncio import streams
4140
from bson import DEFAULT_CODEC_OPTIONS
4241
from pymongo import _csot, helpers_shared
4342
from pymongo.asynchronous.client_session import _validate_session_write_concern
4443
from pymongo.asynchronous.helpers import _handle_reauth
45-
from pymongo.asynchronous.network import command_stream, receive_message
44+
from pymongo.asynchronous.network import command
4645
from pymongo.common import (
4746
MAX_BSON_SIZE,
4847
MAX_MESSAGE_SIZE,
@@ -80,7 +79,7 @@
8079
ConnectionCheckOutFailedReason,
8180
ConnectionClosedReason,
8281
)
83-
from pymongo.network_layer import async_sendall, _UNPACK_HEADER, PyMongoProtocol
82+
from pymongo.network_layer import PyMongoProtocol, async_receive_message, async_sendall
8483
from pymongo.pool_options import PoolOptions
8584
from pymongo.read_preferences import ReadPreference
8685
from pymongo.server_api import _add_to_command
@@ -534,7 +533,7 @@ async def command(
534533
if self.op_msg_enabled:
535534
self._raise_if_not_writable(unacknowledged)
536535
try:
537-
return await command_stream(
536+
return await command(
538537
self,
539538
dbname,
540539
spec,
@@ -578,7 +577,6 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
578577
try:
579578
await async_sendall(self.conn, message)
580579
except BaseException as error:
581-
print(error)
582580
self._raise_connection_failure(error)
583581

584582
async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
@@ -587,7 +585,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
587585
If any exception is raised, the socket is closed.
588586
"""
589587
try:
590-
return await receive_message(self, request_id, self.max_message_size)
588+
return await async_receive_message(self, request_id, self.max_message_size)
591589
except BaseException as error:
592590
self._raise_connection_failure(error)
593591

@@ -795,7 +793,11 @@ class AsyncConnectionProtocol:
795793
"""
796794

797795
def __init__(
798-
self, conn: tuple[asyncio.BaseTransport, PyMongoProtocol], pool: Pool, address: tuple[str, int], id: int
796+
self,
797+
conn: tuple[asyncio.BaseTransport, PyMongoProtocol],
798+
pool: Pool,
799+
address: tuple[str, int],
800+
id: int,
799801
):
800802
self.pool_ref = weakref.ref(pool)
801803
self.conn = conn
@@ -1066,7 +1068,7 @@ async def command(
10661068
if self.op_msg_enabled:
10671069
self._raise_if_not_writable(unacknowledged)
10681070
try:
1069-
return await command_stream(
1071+
return await command(
10701072
self,
10711073
dbname,
10721074
spec,
@@ -1118,7 +1120,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
11181120
If any exception is raised, the socket is closed.
11191121
"""
11201122
try:
1121-
return await receive_message(self, request_id, self.max_message_size)
1123+
return await async_receive_message(self, request_id, self.max_message_size)
11221124
except BaseException as error:
11231125
self._raise_connection_failure(error)
11241126

@@ -1316,8 +1318,6 @@ def __repr__(self) -> str:
13161318
)
13171319

13181320

1319-
1320-
13211321
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
13221322
"""Given (host, port) and PoolOptions, connect and return a socket object.
13231323
@@ -1400,15 +1400,23 @@ async def _configured_stream(
14001400
"""
14011401
sock = _create_connection(address, options)
14021402
ssl_context = options._ssl_context
1403+
timeout = sock.gettimeout()
14031404

14041405
if ssl_context is None:
1405-
return await asyncio.get_running_loop().create_connection(lambda: PyMongoProtocol(), sock=sock)
1406+
return await asyncio.get_running_loop().create_connection(
1407+
lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**16), sock=sock
1408+
)
14061409

14071410
host = address[0]
14081411
try:
14091412
# We have to pass hostname / ip address to wrap_socket
14101413
# to use SSLContext.check_hostname.
1411-
transport, protocol = await asyncio.get_running_loop().create_connection(lambda: PyMongoProtocol(), sock=sock, server_hostname=host, ssl=ssl_context)
1414+
transport, protocol = await asyncio.get_running_loop().create_connection(
1415+
lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**14),
1416+
sock=sock,
1417+
server_hostname=host,
1418+
ssl=ssl_context,
1419+
)
14121420
except _CertificateError:
14131421
transport.close()
14141422
# Raise _CertificateError directly like we do after match_hostname
@@ -1819,7 +1827,9 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
18191827
self.requests -= 1
18201828
self.size_cond.notify()
18211829

1822-
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnectionProtocol:
1830+
async def connect(
1831+
self, handler: Optional[_MongoClientErrorHandler] = None
1832+
) -> AsyncConnectionProtocol:
18231833
"""Connect to Mongo and return a new AsyncConnection.
18241834
18251835
Can raise ConnectionFailure.
@@ -1849,7 +1859,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
18491859
)
18501860

18511861
try:
1852-
sock = await _configured_stream(self.address, self.opts)
1862+
transport, protocol = await _configured_stream(self.address, self.opts)
18531863
except BaseException as error:
18541864
async with self.lock:
18551865
self.active_contexts.discard(tmp_context)
@@ -1875,7 +1885,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
18751885

18761886
raise
18771887

1878-
conn = AsyncConnectionProtocol(sock, self, self.address, conn_id) # type: ignore[arg-type]
1888+
conn = AsyncConnectionProtocol((transport, protocol), self, self.address, conn_id) # type: ignore[arg-type]
18791889
async with self.lock:
18801890
self.active_contexts.add(conn.cancel_context)
18811891
self.active_contexts.discard(tmp_context)

pymongo/asynchronous/server.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from __future__ import annotations
1717

1818
import logging
19-
import statistics
20-
import time
2119
from datetime import datetime
2220
from typing import (
2321
TYPE_CHECKING,
@@ -60,12 +58,6 @@
6058
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
6159

6260

63-
# TOTAL = []
64-
# TOTAL_WRITE = []
65-
# TOTAL_READ = []
66-
# print(f"TOTALS: {TOTAL, TOTAL_WRITE, TOTAL_READ}")
67-
68-
6961
class Server:
7062
def __init__(
7163
self,
@@ -212,19 +204,8 @@ async def run_operation(
212204
if more_to_come:
213205
reply = await conn.receive_message(None)
214206
else:
215-
write_start = time.monotonic()
216207
await conn.send_message(data, max_doc_size)
217-
write_elapsed = time.monotonic() - write_start
218-
219-
read_start = time.monotonic()
220208
reply = await conn.receive_message(request_id)
221-
read_elapsed = time.monotonic() - read_start
222-
223-
# TOTAL.append(write_elapsed + read_elapsed)
224-
# TOTAL_READ.append(read_elapsed)
225-
# TOTAL_WRITE.append(write_elapsed)
226-
# print(
227-
# f"AVERAGE READ: {statistics.mean(TOTAL_READ)}, AVERAGE WRITE: {statistics.mean(TOTAL_WRITE)}, AVERAGE ELAPSED: {statistics.mean(TOTAL)}")
228209

229210
# Unpack and check for command errors.
230211
if use_cmd:

pymongo/connection.py

Whitespace-only changes.

pymongo/message.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
"""
2222
from __future__ import annotations
2323

24-
import asyncio
2524
import datetime
2625
import random
2726
import struct
@@ -1547,7 +1546,9 @@ def unpack(cls, msg: bytes) -> _OpMsg:
15471546
raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}")
15481547

15491548
if len(msg) != first_payload_size + 5:
1550-
raise ProtocolError(f"Unsupported OP_MSG reply: >1 section, {len(msg)} vs {first_payload_size + 5}")
1549+
raise ProtocolError(
1550+
f"Unsupported OP_MSG reply: >1 section, {len(msg)} vs {first_payload_size + 5}"
1551+
)
15511552

15521553
payload_document = msg[5:]
15531554
return cls(flags, payload_document)

0 commit comments

Comments
 (0)