|
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
18 | 18 | import asyncio
|
| 19 | +import collections |
19 | 20 | import errno
|
20 | 21 | import socket
|
21 | 22 | import statistics
|
|
70 | 71 | BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
71 | 72 |
|
72 | 73 |
|
73 |
| -class PyMongoProtocol(asyncio.Protocol): |
| 74 | +class PyMongoProtocol(asyncio.BufferedProtocol): |
74 | 75 | def __init__(self):
|
75 | 76 | self.transport = None
|
76 |
| - self.done = None |
77 |
| - self.buffer = None |
| 77 | + self._buffer = None |
78 | 78 | self.expected_length = 0
|
79 | 79 | self.expecting_header = False
|
80 | 80 | self.bytes_read = 0
|
81 | 81 | self.op_code = None
|
| 82 | + self._done = None |
| 83 | + self._connection_lost = False |
| 84 | + self._paused = False |
| 85 | + self._drain_waiters = collections.deque() |
| 86 | + self._loop = asyncio.get_running_loop() |
82 | 87 |
|
83 | 88 | def connection_made(self, transport):
|
84 | 89 | self.transport = transport
|
85 | 90 |
|
86 |
| - def write(self, message: bytes): |
| 91 | + async def write(self, message: bytes): |
87 | 92 | self.transport.write(message)
|
| 93 | + await self._drain_helper() |
88 | 94 |
|
89 |
| - def data_received(self, data): |
90 |
| - size = len(data) |
91 |
| - if size == 0: |
| 95 | + async def read(self): |
| 96 | + self._done = self._loop.create_future() |
| 97 | + await self._done |
| 98 | + return self.expected_length, self.op_code |
| 99 | + |
| 100 | + def get_buffer(self, sizehint: int): |
| 101 | + return self._buffer[self.bytes_read:] |
| 102 | + |
| 103 | + def buffer_updated(self, nbytes: int): |
| 104 | + if nbytes == 0: |
92 | 105 | raise OSError("connection closed")
|
93 |
| - self.buffer[self.bytes_read:self.bytes_read + size] = data |
94 |
| - self.bytes_read += size |
| 106 | + self.bytes_read += nbytes |
95 | 107 | if self.expecting_header:
|
96 |
| - self.expected_length, _, response_to, self.op_code = _UNPACK_HEADER(self.buffer[:16]) |
| 108 | + self.expected_length, _, response_to, self.op_code = _UNPACK_HEADER(self._buffer[:16]) |
97 | 109 | self.expecting_header = False
|
98 | 110 |
|
99 | 111 | if self.bytes_read == self.expected_length:
|
100 |
| - self.done.set_result((self.expected_length, self.op_code)) |
| 112 | + self._done.set_result((self.expected_length, self.op_code)) |
| 113 | + |
| 114 | + def pause_writing(self): |
| 115 | + assert not self._paused |
| 116 | + self._paused = True |
| 117 | + |
| 118 | + def resume_writing(self): |
| 119 | + assert self._paused |
| 120 | + self._paused = False |
| 121 | + |
| 122 | + for waiter in self._drain_waiters: |
| 123 | + if not waiter.done(): |
| 124 | + waiter.set_result(None) |
101 | 125 |
|
102 | 126 | def connection_lost(self, exc):
|
103 |
| - if self.done and not self.done.done(): |
104 |
| - self.done.set_result(True) |
| 127 | + self._connection_lost = True |
| 128 | + # Wake up the writer(s) if currently paused. |
| 129 | + if not self._paused: |
| 130 | + return |
105 | 131 |
|
106 |
| - def reset(self, buffer: memoryview, done: asyncio.Future): |
107 |
| - self.buffer = buffer |
108 |
| - self.done = done |
| 132 | + for waiter in self._drain_waiters: |
| 133 | + if not waiter.done(): |
| 134 | + if exc is None: |
| 135 | + waiter.set_result(None) |
| 136 | + else: |
| 137 | + waiter.set_exception(exc) |
| 138 | + |
| 139 | + async def _drain_helper(self): |
| 140 | + if self._connection_lost: |
| 141 | + raise ConnectionResetError('Connection lost') |
| 142 | + if not self._paused: |
| 143 | + return |
| 144 | + waiter = self._loop.create_future() |
| 145 | + self._drain_waiters.append(waiter) |
| 146 | + try: |
| 147 | + await waiter |
| 148 | + finally: |
| 149 | + self._drain_waiters.remove(waiter) |
| 150 | + |
| 151 | + def reset(self, buffer: memoryview): |
| 152 | + self._buffer = buffer |
109 | 153 | self.bytes_read = 0
|
110 | 154 | self.expecting_header = True
|
111 | 155 | self.op_code = None
|
112 | 156 |
|
| 157 | + def data(self): |
| 158 | + return self._buffer |
| 159 | + |
113 | 160 |
|
114 | 161 | async def async_sendall_stream(stream: AsyncConnectionStream, buf: bytes) -> None:
|
115 | 162 | try:
|
116 |
| - stream.conn[1].write(buf) |
| 163 | + await asyncio.wait_for(stream.conn[1].write(buf), timeout=None) |
117 | 164 | except asyncio.TimeoutError as exc:
|
118 | 165 | # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
|
119 | 166 | raise socket.timeout("timed out") from exc
|
@@ -145,9 +192,8 @@ async def async_receive_data_stream(
|
145 | 192 | # else:
|
146 | 193 | # timeout = sock_timeout
|
147 | 194 | loop = asyncio.get_running_loop()
|
148 |
| - |
149 | 195 | done = loop.create_future()
|
150 |
| - conn.conn[1].setup(done, length) |
| 196 | + conn.conn[1].reset(done, length) |
151 | 197 | try:
|
152 | 198 | await asyncio.wait_for(done, timeout=None)
|
153 | 199 | return done.result()
|
|
0 commit comments