Skip to content

Commit 51b6537

Browse files
committed
Use sliding buffer for protocols
1 parent f638c04 commit 51b6537

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

pymongo/asynchronous/network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ async def receive_message_stream(
332332
# Ignore the response's request id.
333333
# data = bytearray(max_message_size)
334334
conn.conn[1].reset()
335-
length, op_code = await asyncio.wait_for(conn.conn[1].read(), timeout=None)
335+
data, op_code = await asyncio.wait_for(conn.conn[1].read(), timeout=None)
336336

337337
# length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data_stream(conn, 16, deadline))
338338
# # No request_id for exhaust cursor "getMore".
@@ -362,5 +362,5 @@ async def receive_message_stream(
362362
raise ProtocolError(
363363
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
364364
) from None
365-
return unpack_reply(conn.conn[1].data()[16:length])
365+
return unpack_reply(data)
366366

pymongo/network_layer.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,15 @@ def __init__(self):
8080
self._buffer = memoryview(bytearray(MAX_MESSAGE_SIZE))
8181
self.expected_length = 0
8282
self.expecting_header = False
83-
self.bytes_read = 0
83+
self.ready_offset = 0
84+
self.empty_offset = 0
8485
self.op_code = None
8586
self._done = None
8687
self._connection_lost = False
8788
self._paused = False
8889
self._drain_waiter = None
8990
self._loop = asyncio.get_running_loop()
91+
self._messages = collections.deque()
9092

9193
def connection_made(self, transport):
9294
self.transport = transport
@@ -96,24 +98,71 @@ async def write(self, message: bytes):
9698
await self._drain_helper()
9799

98100
async def read(self):
99-
self._done = self._loop.create_future()
100-
await self._done
101-
return self.expected_length, self.op_code
101+
data, opcode, to_remove = None, None, None
102+
for message in self._messages:
103+
if message.done():
104+
data, opcode = self.unpack_message(message)
105+
to_remove = message
106+
if to_remove:
107+
self._messages.remove(to_remove)
108+
else:
109+
message = self._loop.create_future()
110+
self._messages.append(message)
111+
try:
112+
await message
113+
finally:
114+
self._messages.remove(message)
115+
data, opcode = self.unpack_message(message)
116+
return data, opcode
117+
118+
def unpack_message(self, message):
119+
start, end, opcode = message.result()
120+
if isinstance(start, tuple):
121+
return memoryview(
122+
self._buffer[start[0]:end[0]].tobytes() + self._buffer[start[1]:end[1]].tobytes()), opcode
123+
else:
124+
return self._buffer[start:end], opcode
102125

103126
def get_buffer(self, sizehint: int):
104-
return self._buffer[self.bytes_read:]
127+
if self.empty_offset + sizehint >= MAX_MESSAGE_SIZE - 1:
128+
self.empty_offset = 0
129+
if self.empty_offset < self.ready_offset:
130+
return self._buffer[self.empty_offset:self.ready_offset]
131+
else:
132+
return self._buffer[self.empty_offset:]
105133

106134
def buffer_updated(self, nbytes: int):
107135
if nbytes == 0:
108136
self.connection_lost(OSError("connection closed"))
109137
self._done.set_result(None)
110-
self.bytes_read += nbytes
138+
self.empty_offset += nbytes
111139
if self.expecting_header:
112-
self.expected_length, _, _, self.op_code = _UNPACK_HEADER(self._buffer[:16])
140+
self.expected_length, _, _, self.op_code = _UNPACK_HEADER(self._buffer[self.ready_offset:self.ready_offset + 16])
113141
self.expecting_header = False
114142

115-
if self.bytes_read == self.expected_length:
116-
self._done.set_result((self.expected_length, self.op_code))
143+
if self.ready_offset < self.empty_offset:
144+
if self.empty_offset - self.ready_offset >= self.expected_length:
145+
self.store_message(self.ready_offset + 16, self.ready_offset + self.expected_length, self.op_code)
146+
self.ready_offset += self.expected_length
147+
else:
148+
if self.ready_offset + self.expected_length <= MAX_MESSAGE_SIZE - 1:
149+
self.store_message(self.ready_offset + 16, self.ready_offset + self.expected_length, self.op_code)
150+
self.ready_offset += self.expected_length
151+
elif MAX_MESSAGE_SIZE - 1 - self.ready_offset + self.empty_offset >= self.expected_length:
152+
self.store_message((self.ready_offset, 0), (MAX_MESSAGE_SIZE - 1, self.expected_length - (MAX_MESSAGE_SIZE - 1 - self.ready_offset)), self.op_code)
153+
self.ready_offset = self.expected_length - (MAX_MESSAGE_SIZE - 1 - self.ready_offset)
154+
155+
def store_message(self, start, end, opcode):
156+
stored = False
157+
for message in self._messages:
158+
if not message.done():
159+
message.set_result((start, end, opcode))
160+
stored = True
161+
if not stored:
162+
message = self._loop.create_future()
163+
message.set_result((start, end, opcode))
164+
self._messages.append(message)
165+
self.expecting_header = True
117166

118167
def pause_writing(self):
119168
assert not self._paused

0 commit comments

Comments
 (0)