16
16
from __future__ import annotations
17
17
18
18
import asyncio
19
+ import collections
19
20
import errno
20
21
import socket
21
22
import struct
@@ -141,6 +142,7 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: Optional[int] =
141
142
self .transport = None
142
143
self ._buffer = memoryview (bytearray (self ._buffer_size ))
143
144
self ._overflow = None
145
+ self ._start = 0
144
146
self ._length = 0
145
147
self ._overflow_length = 0
146
148
self ._body_length = 0
@@ -157,7 +159,9 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: Optional[int] =
157
159
self ._request_id = None
158
160
self ._closed = asyncio .get_running_loop ().create_future ()
159
161
self ._debug = False
160
-
162
+ self ._expecting_header = True
163
+ self ._pending_messages = collections .deque ()
164
+ self ._done_messages = collections .deque ()
161
165
162
166
def settimeout (self , timeout : float | None ):
163
167
self ._timeout = timeout
@@ -182,24 +186,31 @@ async def write(self, message: bytes):
182
186
183
187
async def read (self , request_id : Optional [int ], max_message_size : int , debug : bool = False ):
184
188
"""Read a single MongoDB Wire Protocol message from this connection."""
185
- self ._debug = debug
186
- self ._max_message_size = max_message_size
187
- self ._request_id = request_id
188
- self ._length , self ._overflow_length , self ._body_length , self ._op_code , self ._overflow = (
189
- 0 ,
190
- 0 ,
191
- 0 ,
192
- None ,
193
- None ,
194
- )
195
- if self .transport .is_closing ():
196
- print ("Connection is closed" )
197
- raise OSError ("Connection is closed" )
198
- self ._read_waiter = asyncio .get_running_loop ().create_future ()
199
- await self ._read_waiter
200
- if self ._read_waiter .done () and self ._read_waiter .result ():
201
- if self ._debug :
202
- print ("Read waiter done" )
189
+ if self ._done_messages :
190
+ message = await self ._done_messages .popleft ()
191
+ else :
192
+ self ._expecting_header = True
193
+ self ._debug = debug
194
+ self ._max_message_size = max_message_size
195
+ self ._request_id = request_id
196
+ self ._length , self ._overflow_length , self ._body_length , self ._op_code , self ._overflow = (
197
+ 0 ,
198
+ 0 ,
199
+ 0 ,
200
+ None ,
201
+ None ,
202
+ )
203
+ if self .transport .is_closing ():
204
+ raise OSError ("Connection is closed" )
205
+ read_waiter = asyncio .get_running_loop ().create_future ()
206
+ self ._pending_messages .append (read_waiter )
207
+ try :
208
+ message = await read_waiter
209
+ finally :
210
+ if read_waiter in self ._done_messages :
211
+ self ._done_messages .remove (read_waiter )
212
+ if message :
213
+ start , end = message [0 ], message [1 ]
203
214
header_size = 16
204
215
if self ._body_length > self ._buffer_size :
205
216
if self ._is_compressed :
@@ -220,21 +231,17 @@ async def read(self, request_id: Optional[int], max_message_size: int, debug: bo
220
231
if self ._is_compressed :
221
232
header_size = 25
222
233
return decompress (
223
- memoryview (self ._buffer [header_size : self . _body_length ]),
234
+ memoryview (self ._buffer [start + header_size : end ]),
224
235
self ._compressor_id ,
225
236
), self ._op_code
226
237
else :
227
- return memoryview (self ._buffer [header_size : self . _body_length ]), self ._op_code
238
+ return memoryview (self ._buffer [start + header_size : end ]), self ._op_code
228
239
raise OSError ("connection closed" )
229
240
230
241
def get_buffer (self , sizehint : int ):
231
242
"""Called to allocate a new receive buffer."""
232
243
if self ._overflow is not None :
233
- if len (self ._overflow [self ._overflow_length :]) == 0 :
234
- print (f"Overflow buffer overflow, overflow size of { len (self ._overflow )} " )
235
244
return self ._overflow [self ._overflow_length :]
236
- if len (self ._buffer [self ._length :]) == 0 :
237
- print (f"Default buffer overflow, overflow size of { len (self ._buffer )} " )
238
245
return self ._buffer [self ._length :]
239
246
240
247
def buffer_updated (self , nbytes : int ):
@@ -248,29 +255,31 @@ def buffer_updated(self, nbytes: int):
248
255
if self ._overflow is not None :
249
256
self ._overflow_length += nbytes
250
257
else :
251
- if self ._length == 0 :
258
+ if self ._expecting_header :
252
259
try :
253
260
self ._body_length , self ._op_code = self .process_header ()
254
261
except ProtocolError as exc :
255
- if self ._debug :
256
- print (f"Protocol error: { exc } " )
257
262
self .connection_lost (exc )
258
263
return
264
+ self ._expecting_header = False
259
265
if self ._body_length > self ._buffer_size :
260
266
self ._overflow = memoryview (
261
267
bytearray (self ._body_length - (self ._buffer_size - nbytes ) + 1000 )
262
268
)
263
269
self ._length += nbytes
264
- if (
265
- self ._length + self ._overflow_length >= self ._body_length
266
- and self ._read_waiter
267
- and not self ._read_waiter .done ()
268
- ):
270
+ if self ._length + self ._overflow_length >= self ._body_length and self ._pending_messages and not self ._pending_messages [0 ].done ():
271
+ done = self ._pending_messages .popleft ()
272
+ done .set_result ((self ._start , self ._body_length ))
273
+ self ._done_messages .append (done )
269
274
if self ._length > self ._body_length :
270
- self ._body_length = self ._length
271
- if self ._length + self ._overflow_length > self ._body_length :
272
- print (f"Done reading with length { self ._length + self ._overflow_length } out of { self ._body_length } " )
273
- self ._read_waiter .set_result (True )
275
+ print ("Larger than expected length" )
276
+ self ._read_waiter = asyncio .get_running_loop ().create_future ()
277
+ self ._pending_messages .append (self ._read_waiter )
278
+ self ._start = self ._body_length
279
+ extra = self ._length - self ._body_length
280
+ self ._length -= extra
281
+ self ._expecting_header = True
282
+ self .buffer_updated (extra )
274
283
275
284
def process_header (self ):
276
285
"""Unpack a MongoDB Wire Protocol header."""
@@ -312,11 +321,13 @@ def resume_writing(self):
312
321
313
322
def connection_lost (self , exc ):
314
323
self ._connection_lost = True
315
- if self ._read_waiter and not self ._read_waiter .done ():
324
+ pending = [msg for msg in self ._pending_messages ]
325
+ for msg in pending :
316
326
if exc is None :
317
- self . _read_waiter .set_result (None )
327
+ msg .set_result (None )
318
328
else :
319
- self ._read_waiter .set_exception (exc )
329
+ msg .set_exception (exc )
330
+ self ._done_messages .append (msg )
320
331
321
332
if not self ._closed .done ():
322
333
if exc is None :
@@ -441,12 +452,6 @@ async def async_receive_message(
441
452
# timeouts on AWS Lambda and other FaaS environments.
442
453
timeout = max (deadline - time .monotonic (), 0 )
443
454
444
- # if debug:
445
- # print(f"async_receive_message with timeout: {timeout}. From csot: {_csot.get_timeout()}, from conn: {conn.conn.get_conn.gettimeout}, deadline: {deadline} ")
446
- # if timeout is None:
447
- # timeout = 5.0
448
-
449
-
450
455
cancellation_task = create_task (_poll_cancellation (conn ))
451
456
read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size , debug ))
452
457
tasks = [read_task , cancellation_task ]
0 commit comments