1616from __future__ import annotations
1717
1818import asyncio
19+ import collections
1920import errno
2021import socket
2122import struct
@@ -141,6 +142,7 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: Optional[int] =
141142 self .transport = None
142143 self ._buffer = memoryview (bytearray (self ._buffer_size ))
143144 self ._overflow = None
145+ self ._start = 0
144146 self ._length = 0
145147 self ._overflow_length = 0
146148 self ._body_length = 0
@@ -157,7 +159,9 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: Optional[int] =
157159 self ._request_id = None
158160 self ._closed = asyncio .get_running_loop ().create_future ()
159161 self ._debug = False
160-
162+ self ._expecting_header = True
163+ self ._pending_messages = collections .deque ()
164+ self ._done_messages = collections .deque ()
161165
162166 def settimeout (self , timeout : float | None ):
163167 self ._timeout = timeout
@@ -182,24 +186,31 @@ async def write(self, message: bytes):
182186
183187 async def read (self , request_id : Optional [int ], max_message_size : int , debug : bool = False ):
184188 """Read a single MongoDB Wire Protocol message from this connection."""
185- self ._debug = debug
186- self ._max_message_size = max_message_size
187- self ._request_id = request_id
188- self ._length , self ._overflow_length , self ._body_length , self ._op_code , self ._overflow = (
189- 0 ,
190- 0 ,
191- 0 ,
192- None ,
193- None ,
194- )
195- if self .transport .is_closing ():
196- print ("Connection is closed" )
197- raise OSError ("Connection is closed" )
198- self ._read_waiter = asyncio .get_running_loop ().create_future ()
199- await self ._read_waiter
200- if self ._read_waiter .done () and self ._read_waiter .result ():
201- if self ._debug :
202- print ("Read waiter done" )
189+ if self ._done_messages :
190+ message = await self ._done_messages .popleft ()
191+ else :
192+ self ._expecting_header = True
193+ self ._debug = debug
194+ self ._max_message_size = max_message_size
195+ self ._request_id = request_id
196+ self ._length , self ._overflow_length , self ._body_length , self ._op_code , self ._overflow = (
197+ 0 ,
198+ 0 ,
199+ 0 ,
200+ None ,
201+ None ,
202+ )
203+ if self .transport .is_closing ():
204+ raise OSError ("Connection is closed" )
205+ read_waiter = asyncio .get_running_loop ().create_future ()
206+ self ._pending_messages .append (read_waiter )
207+ try :
208+ message = await read_waiter
209+ finally :
210+ if read_waiter in self ._done_messages :
211+ self ._done_messages .remove (read_waiter )
212+ if message :
213+ start , end = message [0 ], message [1 ]
203214 header_size = 16
204215 if self ._body_length > self ._buffer_size :
205216 if self ._is_compressed :
@@ -220,21 +231,17 @@ async def read(self, request_id: Optional[int], max_message_size: int, debug: bo
220231 if self ._is_compressed :
221232 header_size = 25
222233 return decompress (
223- memoryview (self ._buffer [header_size : self . _body_length ]),
234+ memoryview (self ._buffer [start + header_size : end ]),
224235 self ._compressor_id ,
225236 ), self ._op_code
226237 else :
227- return memoryview (self ._buffer [header_size : self . _body_length ]), self ._op_code
238+ return memoryview (self ._buffer [start + header_size : end ]), self ._op_code
228239 raise OSError ("connection closed" )
229240
230241 def get_buffer (self , sizehint : int ):
231242 """Called to allocate a new receive buffer."""
232243 if self ._overflow is not None :
233- if len (self ._overflow [self ._overflow_length :]) == 0 :
234- print (f"Overflow buffer overflow, overflow size of { len (self ._overflow )} " )
235244 return self ._overflow [self ._overflow_length :]
236- if len (self ._buffer [self ._length :]) == 0 :
237- print (f"Default buffer overflow, overflow size of { len (self ._buffer )} " )
238245 return self ._buffer [self ._length :]
239246
240247 def buffer_updated (self , nbytes : int ):
@@ -248,29 +255,31 @@ def buffer_updated(self, nbytes: int):
248255 if self ._overflow is not None :
249256 self ._overflow_length += nbytes
250257 else :
251- if self ._length == 0 :
258+ if self ._expecting_header :
252259 try :
253260 self ._body_length , self ._op_code = self .process_header ()
254261 except ProtocolError as exc :
255- if self ._debug :
256- print (f"Protocol error: { exc } " )
257262 self .connection_lost (exc )
258263 return
264+ self ._expecting_header = False
259265 if self ._body_length > self ._buffer_size :
260266 self ._overflow = memoryview (
261267 bytearray (self ._body_length - (self ._buffer_size - nbytes ) + 1000 )
262268 )
263269 self ._length += nbytes
264- if (
265- self ._length + self ._overflow_length >= self ._body_length
266- and self ._read_waiter
267- and not self ._read_waiter .done ()
268- ):
270+ if self ._length + self ._overflow_length >= self ._body_length and self ._pending_messages and not self ._pending_messages [0 ].done ():
271+ done = self ._pending_messages .popleft ()
272+ done .set_result ((self ._start , self ._body_length ))
273+ self ._done_messages .append (done )
269274 if self ._length > self ._body_length :
270- self ._body_length = self ._length
271- if self ._length + self ._overflow_length > self ._body_length :
272- print (f"Done reading with length { self ._length + self ._overflow_length } out of { self ._body_length } " )
273- self ._read_waiter .set_result (True )
275+ print ("Larger than expected length" )
276+ self ._read_waiter = asyncio .get_running_loop ().create_future ()
277+ self ._pending_messages .append (self ._read_waiter )
278+ self ._start = self ._body_length
279+ extra = self ._length - self ._body_length
280+ self ._length -= extra
281+ self ._expecting_header = True
282+ self .buffer_updated (extra )
274283
275284 def process_header (self ):
276285 """Unpack a MongoDB Wire Protocol header."""
@@ -312,11 +321,13 @@ def resume_writing(self):
312321
313322 def connection_lost (self , exc ):
314323 self ._connection_lost = True
315- if self ._read_waiter and not self ._read_waiter .done ():
324+ pending = [msg for msg in self ._pending_messages ]
325+ for msg in pending :
316326 if exc is None :
317- self . _read_waiter .set_result (None )
327+ msg .set_result (None )
318328 else :
319- self ._read_waiter .set_exception (exc )
329+ msg .set_exception (exc )
330+ self ._done_messages .append (msg )
320331
321332 if not self ._closed .done ():
322333 if exc is None :
@@ -441,12 +452,6 @@ async def async_receive_message(
441452 # timeouts on AWS Lambda and other FaaS environments.
442453 timeout = max (deadline - time .monotonic (), 0 )
443454
444- # if debug:
445- # print(f"async_receive_message with timeout: {timeout}. From csot: {_csot.get_timeout()}, from conn: {conn.conn.get_conn.gettimeout}, deadline: {deadline} ")
446- # if timeout is None:
447- # timeout = 5.0
448-
449-
450455 cancellation_task = create_task (_poll_cancellation (conn ))
451456 read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size , debug ))
452457 tasks = [read_task , cancellation_task ]
0 commit comments