Skip to content

Commit 3e9d992

Browse files
committed
Spike
1 parent 39dcb1e commit 3e9d992

File tree

3 files changed

+104
-67
lines changed

3 files changed

+104
-67
lines changed

pymongo/asynchronous/network.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Internal network layer helper methods."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import datetime
1920
import logging
2021
import statistics
@@ -200,23 +201,23 @@ async def command_stream(
200201

201202
try:
202203
write_start = time.monotonic()
203-
await async_sendall_stream(conn.conn[1], msg)
204+
await async_sendall_stream(conn, msg)
204205
write_elapsed = time.monotonic() - write_start
205206
if use_op_msg and unacknowledged:
206207
# Unacknowledged, fake a successful command response.
207208
reply = None
208209
response_doc: _DocumentOut = {"ok": 1}
209210
else:
210211
read_start = time.monotonic()
211-
reply = await receive_message_stream(conn.conn[0], request_id)
212+
reply = await receive_message_stream(conn, request_id)
212213
read_elapsed = time.monotonic() - read_start
213-
# if name == "insert":
214-
# TOTAL.append(write_elapsed + read_elapsed)
215-
# TOTAL_READ.append(read_elapsed)
216-
# TOTAL_WRITE.append(write_elapsed)
217-
# if name == "endSessions":
218-
# print(
219-
# f"AVERAGE READ: {statistics.mean(TOTAL_READ)}, AVERAGE WRITE: {statistics.mean(TOTAL_WRITE)}, AVERAGE ELAPSED: {statistics.mean(TOTAL)}")
214+
if name == "insert":
215+
TOTAL.append(write_elapsed + read_elapsed)
216+
TOTAL_READ.append(read_elapsed)
217+
TOTAL_WRITE.append(write_elapsed)
218+
if name == "endSessions":
219+
print(
220+
f"AVERAGE READ: {statistics.mean(TOTAL_READ)}, AVERAGE WRITE: {statistics.mean(TOTAL_WRITE)}, AVERAGE ELAPSED: {statistics.mean(TOTAL)}")
220221
conn.more_to_come = reply.more_to_come
221222
unpacked_docs = reply.unpack_response(
222223
codec_options=codec_options, user_fields=user_fields
@@ -316,7 +317,7 @@ async def command_stream(
316317

317318

318319
async def receive_message_stream(
319-
conn: StreamReader, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
320+
conn: AsyncConnectionStream, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
320321
) -> Union[_OpReply, _OpMsg]:
321322
"""Receive a raw BSON message or raise socket.error."""
322323
# if _csot.get_timeout():
@@ -329,39 +330,40 @@ async def receive_message_stream(
329330
# deadline = None
330331
deadline = None
331332
# Ignore the response's request id.
332-
read_start = time.monotonic()
333-
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data_stream(conn, 16, deadline))
334-
read_elapsed = time.monotonic() - read_start
335-
# print(f"Read header in {read_elapsed}")
336-
# No request_id for exhaust cursor "getMore".
337-
if request_id is not None:
338-
if request_id != response_to:
339-
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
340-
if length <= 16:
341-
raise ProtocolError(
342-
f"Message length ({length!r}) not longer than standard message header size (16)"
343-
)
344-
if length > max_message_size:
345-
raise ProtocolError(
346-
f"Message length ({length!r}) is larger than server max "
347-
f"message size ({max_message_size!r})"
348-
)
349-
if op_code == 2012:
350-
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
351-
await async_receive_data_stream(conn, 9, deadline)
352-
)
353-
data = decompress(await async_receive_data_stream(conn, length - 25, deadline), compressor_id)
354-
else:
355-
read_start = time.monotonic()
356-
data = await async_receive_data_stream(conn, length - 16, deadline)
357-
read_elapsed = time.monotonic() - read_start
358-
# print(f"Read body in {read_elapsed}")
333+
loop = asyncio.get_running_loop()
334+
done = loop.create_future()
335+
mv = memoryview(bytearray(max_message_size))
336+
conn.conn[1].reset(mv, done)
337+
await asyncio.wait_for(done, timeout=None)
338+
length, op_code = done.result()
339+
340+
# length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data_stream(conn, 16, deadline))
341+
# # No request_id for exhaust cursor "getMore".
342+
# if request_id is not None:
343+
# if request_id != response_to:
344+
# raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
345+
# if length <= 16:
346+
# raise ProtocolError(
347+
# f"Message length ({length!r}) not longer than standard message header size (16)"
348+
# )
349+
# if length > max_message_size:
350+
# raise ProtocolError(
351+
# f"Message length ({length!r}) is larger than server max "
352+
# f"message size ({max_message_size!r})"
353+
# )
354+
# if op_code == 2012:
355+
# op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
356+
# await async_receive_data_stream(conn, 9, deadline)
357+
# )
358+
# data = decompress(await async_receive_data_stream(conn, length - 25, deadline), compressor_id)
359+
# else:
360+
# data = await async_receive_data_stream(conn, length - 16, deadline)
359361

360362
try:
361363
unpack_reply = _UNPACK_REPLY[op_code]
362364
except KeyError:
363365
raise ProtocolError(
364366
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
365367
) from None
366-
return unpack_reply(data)
368+
return unpack_reply(mv[16:length])
367369

pymongo/asynchronous/pool.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
ConnectionCheckOutFailedReason,
8181
ConnectionClosedReason,
8282
)
83-
from pymongo.network_layer import async_sendall_stream
83+
from pymongo.network_layer import async_sendall_stream, _UNPACK_HEADER, PyMongoProtocol
8484
from pymongo.pool_options import PoolOptions
8585
from pymongo.read_preferences import ReadPreference
8686
from pymongo.server_api import _add_to_command
@@ -794,7 +794,7 @@ class AsyncConnectionStream:
794794
"""
795795

796796
def __init__(
797-
self, conn: tuple[asyncio.StreamReader, asyncio.StreamWriter], pool: Pool, address: tuple[str, int], id: int
797+
self, conn: tuple[asyncio.BaseTransport, PyMongoProtocol], pool: Pool, address: tuple[str, int], id: int
798798
):
799799
self.pool_ref = weakref.ref(pool)
800800
self.conn = conn
@@ -1107,7 +1107,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
11071107
)
11081108

11091109
try:
1110-
await async_sendall_stream(self.conn[1], message)
1110+
await async_sendall_stream(self.conn, message)
11111111
except BaseException as error:
11121112
self._raise_connection_failure(error)
11131113

@@ -1117,7 +1117,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
11171117
If any exception is raised, the socket is closed.
11181118
"""
11191119
try:
1120-
return await receive_message_stream(self.conn[0], request_id, self.max_message_size)
1120+
return await receive_message_stream(self.conn, request_id, self.max_message_size)
11211121
except BaseException as error:
11221122
self._raise_connection_failure(error)
11231123

@@ -1235,15 +1235,15 @@ def _close_conn(self) -> None:
12351235
# Note: We catch exceptions to avoid spurious errors on interpreter
12361236
# shutdown.
12371237
try:
1238-
self.conn[1].close()
1238+
self.conn[0].close()
12391239
except asyncio.CancelledError:
12401240
raise
12411241
except Exception: # noqa: S110
12421242
pass
12431243

12441244
def conn_closed(self) -> bool:
12451245
"""Return True if we know socket has been closed, False otherwise."""
1246-
return self.conn[1].is_closing()
1246+
return self.conn[0].is_closing()
12471247

12481248
def send_cluster_time(
12491249
self,
@@ -1315,11 +1315,6 @@ def __repr__(self) -> str:
13151315
)
13161316

13171317

1318-
async def _create_connection_stream(address: _Address, options: PoolOptions) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
1319-
"""Given (host, port) and PoolOptions, connect and return a paired StreamReader and StreamWriter.
1320-
"""
1321-
sock = _create_connection(address, options)
1322-
return await asyncio.open_connection(sock=sock)
13231318

13241319

13251320
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
@@ -1395,34 +1390,31 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
13951390

13961391
async def _configured_stream(
13971392
address: _Address, options: PoolOptions
1398-
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
1393+
) -> tuple[asyncio.BaseTransport, PyMongoProtocol]:
13991394
"""Given (host, port) and PoolOptions, return a configured socket.
14001395
14011396
Can raise socket.error, ConnectionFailure, or _CertificateError.
14021397
14031398
Sets socket's SSL and timeout options.
14041399
"""
1405-
(reader, writer) = await _create_connection_stream(address, options)
1400+
sock = _create_connection(address, options)
14061401
ssl_context = options._ssl_context
14071402

14081403
if ssl_context is None:
1409-
# sock.settimeout(options.socket_timeout)
1410-
return reader, writer
1404+
return await asyncio.get_running_loop().create_connection(lambda: PyMongoProtocol(), sock=sock)
14111405

14121406
host = address[0]
14131407
try:
14141408
# We have to pass hostname / ip address to wrap_socket
14151409
# to use SSLContext.check_hostname.
1416-
await writer.start_tls(ssl_context, server_hostname=host)
1410+
transport, protocol = await asyncio.get_running_loop().create_connection(lambda: PyMongoProtocol(), sock=sock, server_hostname=host, ssl=ssl_context)
14171411
except _CertificateError:
1418-
writer.close()
1419-
await writer.wait_closed()
1412+
transport.close()
14201413
# Raise _CertificateError directly like we do after match_hostname
14211414
# below.
14221415
raise
14231416
except (OSError, SSLError) as exc:
1424-
writer.close()
1425-
await writer.wait_closed()
1417+
transport.close()
14261418
# We raise AutoReconnect for transient and permanent SSL handshake
14271419
# failures alike. Permanent handshake failures, like protocol
14281420
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
@@ -1434,13 +1426,12 @@ async def _configured_stream(
14341426
and not options.tls_allow_invalid_hostnames
14351427
):
14361428
try:
1437-
ssl.match_hostname(writer.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined]
1429+
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined]
14381430
except _CertificateError:
1439-
writer.close()
1440-
await writer.wait_closed()
1431+
transport.close()
14411432
raise
14421433

1443-
return reader, writer
1434+
return transport, protocol
14441435

14451436

14461437
async def _configured_socket(

pymongo/network_layer.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,50 @@
7070
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
7171

7272

73-
async def async_sendall_stream(stream: asyncio.StreamWriter, buf: bytes) -> None:
73+
class PyMongoProtocol(asyncio.Protocol):
74+
def __init__(self):
75+
self.transport = None
76+
self.done = None
77+
self.buffer = None
78+
self.expected_length = 0
79+
self.expecting_header = False
80+
self.bytes_read = 0
81+
self.op_code = None
82+
83+
def connection_made(self, transport):
84+
self.transport = transport
85+
86+
def write(self, message: bytes):
87+
self.transport.write(message)
88+
89+
def data_received(self, data):
90+
size = len(data)
91+
if size == 0:
92+
raise OSError("connection closed")
93+
self.buffer[self.bytes_read:self.bytes_read + size] = data
94+
self.bytes_read += size
95+
if self.expecting_header:
96+
self.expected_length, _, response_to, self.op_code = _UNPACK_HEADER(self.buffer[:16])
97+
self.expecting_header = False
98+
99+
if self.bytes_read == self.expected_length:
100+
self.done.set_result((self.expected_length, self.op_code))
101+
102+
def connection_lost(self, exc):
103+
if self.done and not self.done.done():
104+
self.done.set_result(True)
105+
106+
def reset(self, buffer: memoryview, done: asyncio.Future):
107+
self.buffer = buffer
108+
self.done = done
109+
self.bytes_read = 0
110+
self.expecting_header = True
111+
self.op_code = None
112+
113+
114+
async def async_sendall_stream(stream: AsyncConnectionStream, buf: bytes) -> None:
74115
try:
75-
stream.write(buf)
76-
await asyncio.wait_for(stream.drain(), timeout=None)
116+
stream.conn[1].write(buf)
77117
except asyncio.TimeoutError as exc:
78118
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
79119
raise socket.timeout("timed out") from exc
@@ -92,7 +132,7 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
92132

93133

94134
async def async_receive_data_stream(
95-
conn: StreamReader, length: int, deadline: Optional[float]
135+
conn: AsyncConnectionStream, length: int, deadline: Optional[float]
96136
) -> memoryview:
97137
# sock = conn.conn
98138
# sock_timeout = sock.gettimeout()
@@ -104,9 +144,13 @@ async def async_receive_data_stream(
104144
# timeout = max(deadline - time.monotonic(), 0)
105145
# else:
106146
# timeout = sock_timeout
147+
loop = asyncio.get_running_loop()
107148

149+
done = loop.create_future()
150+
conn.conn[1].setup(done, length)
108151
try:
109-
return await asyncio.wait_for(_async_receive_stream(conn, length), timeout=None)
152+
await asyncio.wait_for(done, timeout=None)
153+
return done.result()
110154
# read_task = create_task(_async_receive_stream(conn, length))
111155
# tasks = [read_task, cancellation_task]
112156
# done, pending = await asyncio.wait(

0 commit comments

Comments
 (0)