Skip to content

Commit 4853245

Browse files
committed
Test BufferedProtocol
1 parent 3e9d992 commit 4853245

File tree

3 files changed

+69
-26
lines changed

3 files changed

+69
-26
lines changed

pymongo/asynchronous/network.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,9 @@ async def receive_message_stream(
330330
# deadline = None
331331
deadline = None
332332
# Ignore the response's request id.
333-
loop = asyncio.get_running_loop()
334-
done = loop.create_future()
335333
mv = memoryview(bytearray(max_message_size))
336-
conn.conn[1].reset(mv, done)
337-
await asyncio.wait_for(done, timeout=None)
338-
length, op_code = done.result()
334+
conn.conn[1].reset(mv)
335+
length, op_code = await asyncio.wait_for(conn.conn[1].read(), timeout=None)
339336

340337
# length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data_stream(conn, 16, deadline))
341338
# # No request_id for exhaust cursor "getMore".

pymongo/asynchronous/pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
11071107
)
11081108

11091109
try:
1110-
await async_sendall_stream(self.conn, message)
1110+
await async_sendall_stream(self, message)
11111111
except BaseException as error:
11121112
self._raise_connection_failure(error)
11131113

@@ -1117,7 +1117,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
11171117
If any exception is raised, the socket is closed.
11181118
"""
11191119
try:
1120-
return await receive_message_stream(self.conn, request_id, self.max_message_size)
1120+
return await receive_message_stream(self, request_id, self.max_message_size)
11211121
except BaseException as error:
11221122
self._raise_connection_failure(error)
11231123

pymongo/network_layer.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import asyncio
19+
import collections
1920
import errno
2021
import socket
2122
import statistics
@@ -70,50 +71,96 @@
7071
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
7172

7273

73-
class PyMongoProtocol(asyncio.Protocol):
74+
class PyMongoProtocol(asyncio.BufferedProtocol):
7475
def __init__(self):
7576
self.transport = None
76-
self.done = None
77-
self.buffer = None
77+
self._buffer = None
7878
self.expected_length = 0
7979
self.expecting_header = False
8080
self.bytes_read = 0
8181
self.op_code = None
82+
self._done = None
83+
self._connection_lost = False
84+
self._paused = False
85+
self._drain_waiters = collections.deque()
86+
self._loop = asyncio.get_running_loop()
8287

8388
def connection_made(self, transport):
8489
self.transport = transport
8590

86-
def write(self, message: bytes):
91+
async def write(self, message: bytes):
8792
self.transport.write(message)
93+
await self._drain_helper()
8894

89-
def data_received(self, data):
90-
size = len(data)
91-
if size == 0:
95+
async def read(self):
96+
self._done = self._loop.create_future()
97+
await self._done
98+
return self.expected_length, self.op_code
99+
100+
def get_buffer(self, sizehint: int):
101+
return self._buffer[self.bytes_read:]
102+
103+
def buffer_updated(self, nbytes: int):
104+
if nbytes == 0:
92105
raise OSError("connection closed")
93-
self.buffer[self.bytes_read:self.bytes_read + size] = data
94-
self.bytes_read += size
106+
self.bytes_read += nbytes
95107
if self.expecting_header:
96-
self.expected_length, _, response_to, self.op_code = _UNPACK_HEADER(self.buffer[:16])
108+
self.expected_length, _, response_to, self.op_code = _UNPACK_HEADER(self._buffer[:16])
97109
self.expecting_header = False
98110

99111
if self.bytes_read == self.expected_length:
100-
self.done.set_result((self.expected_length, self.op_code))
112+
self._done.set_result((self.expected_length, self.op_code))
113+
114+
def pause_writing(self):
115+
assert not self._paused
116+
self._paused = True
117+
118+
def resume_writing(self):
119+
assert self._paused
120+
self._paused = False
121+
122+
for waiter in self._drain_waiters:
123+
if not waiter.done():
124+
waiter.set_result(None)
101125

102126
def connection_lost(self, exc):
103-
if self.done and not self.done.done():
104-
self.done.set_result(True)
127+
self._connection_lost = True
128+
# Wake up the writer(s) if currently paused.
129+
if not self._paused:
130+
return
105131

106-
def reset(self, buffer: memoryview, done: asyncio.Future):
107-
self.buffer = buffer
108-
self.done = done
132+
for waiter in self._drain_waiters:
133+
if not waiter.done():
134+
if exc is None:
135+
waiter.set_result(None)
136+
else:
137+
waiter.set_exception(exc)
138+
139+
async def _drain_helper(self):
140+
if self._connection_lost:
141+
raise ConnectionResetError('Connection lost')
142+
if not self._paused:
143+
return
144+
waiter = self._loop.create_future()
145+
self._drain_waiters.append(waiter)
146+
try:
147+
await waiter
148+
finally:
149+
self._drain_waiters.remove(waiter)
150+
151+
def reset(self, buffer: memoryview):
152+
self._buffer = buffer
109153
self.bytes_read = 0
110154
self.expecting_header = True
111155
self.op_code = None
112156

157+
def data(self):
158+
return self._buffer
159+
113160

114161
async def async_sendall_stream(stream: AsyncConnectionStream, buf: bytes) -> None:
115162
try:
116-
stream.conn[1].write(buf)
163+
await asyncio.wait_for(stream.conn[1].write(buf), timeout=None)
117164
except asyncio.TimeoutError as exc:
118165
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
119166
raise socket.timeout("timed out") from exc
@@ -145,9 +192,8 @@ async def async_receive_data_stream(
145192
# else:
146193
# timeout = sock_timeout
147194
loop = asyncio.get_running_loop()
148-
149195
done = loop.create_future()
150-
conn.conn[1].setup(done, length)
196+
conn.conn[1].reset(done, length)
151197
try:
152198
await asyncio.wait_for(done, timeout=None)
153199
return done.result()

0 commit comments

Comments
 (0)