Skip to content

Commit 83e7e6b

Browse files
committed
Only one drain waiter
1 parent 0f165b7 commit 83e7e6b

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

pymongo/network_layer.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import struct
2424
import sys
2525
import time
26+
import yappi
2627
from asyncio import AbstractEventLoop, Future, StreamReader
2728
from typing import (
2829
TYPE_CHECKING,
@@ -84,7 +85,7 @@ def __init__(self):
8485
self._done = None
8586
self._connection_lost = False
8687
self._paused = False
87-
self._drain_waiters = collections.deque()
88+
self._drain_waiter = None
8889
self._loop = asyncio.get_running_loop()
8990

9091
def connection_made(self, transport):
@@ -104,10 +105,11 @@ def get_buffer(self, sizehint: int):
104105

105106
def buffer_updated(self, nbytes: int):
106107
if nbytes == 0:
107-
raise OSError("connection closed")
108+
self.connection_lost(OSError("connection closed"))
109+
self._done.set_result(None)
108110
self.bytes_read += nbytes
109111
if self.expecting_header:
110-
self.expected_length, _, response_to, self.op_code = _UNPACK_HEADER(self._buffer[:16])
112+
self.expected_length, _, _, self.op_code = _UNPACK_HEADER(self._buffer[:16])
111113
self.expecting_header = False
112114

113115
if self.bytes_read == self.expected_length:
@@ -121,34 +123,28 @@ def resume_writing(self):
121123
assert self._paused
122124
self._paused = False
123125

124-
for waiter in self._drain_waiters:
125-
if not waiter.done():
126-
waiter.set_result(None)
126+
if self._drain_waiter and not self._drain_waiter.done():
127+
self._drain_waiter.set_result(None)
127128

128129
def connection_lost(self, exc):
129130
self._connection_lost = True
130131
# Wake up the writer(s) if currently paused.
131132
if not self._paused:
132133
return
133134

134-
for waiter in self._drain_waiters:
135-
if not waiter.done():
136-
if exc is None:
137-
waiter.set_result(None)
138-
else:
139-
waiter.set_exception(exc)
135+
if self._drain_waiter and not self._drain_waiter.done():
136+
if exc is None:
137+
self._drain_waiter.set_result(None)
138+
else:
139+
self._drain_waiter.set_exception(exc)
140140

141141
async def _drain_helper(self):
142142
if self._connection_lost:
143143
raise ConnectionResetError('Connection lost')
144144
if not self._paused:
145145
return
146-
waiter = self._loop.create_future()
147-
self._drain_waiters.append(waiter)
148-
try:
149-
await waiter
150-
finally:
151-
self._drain_waiters.remove(waiter)
146+
self._drain_waiter = self._loop.create_future()
147+
await self._drain_waiter
152148

153149
def reset(self, buffer: memoryview):
154150
self._buffer = buffer

0 commit comments

Comments
 (0)