26
26
Union ,
27
27
)
28
28
29
- from pymongo import _csot
29
+ from pymongo import _csot , ssl_support
30
30
from pymongo ._asyncio_task import create_task
31
31
from pymongo .common import MAX_MESSAGE_SIZE
32
32
from pymongo .compression_support import decompress
59
59
60
60
61
61
class NetworkingInterfaceBase :
62
- def __init__ (self , conn : Union [socket .socket , _sslConn ] | tuple [asyncio .BaseTransport , PyMongoProtocol ]):
62
+ def __init__ (
63
+ self , conn : Union [socket .socket , _sslConn ] | tuple [asyncio .BaseTransport , PyMongoProtocol ]
64
+ ):
63
65
self .conn = conn
64
66
65
67
def gettimeout (self ):
@@ -74,10 +76,7 @@ def close(self):
74
76
def is_closing (self ) -> bool :
75
77
raise NotImplementedError
76
78
77
- def writer (self ):
78
- raise NotImplementedError
79
-
80
- def reader (self ):
79
+ def get_conn (self ):
81
80
raise NotImplementedError
82
81
83
82
@@ -99,11 +98,7 @@ def is_closing(self):
99
98
self .conn [0 ].is_closing ()
100
99
101
100
@property
102
- def writer (self ) -> PyMongoProtocol :
103
- return self .conn [1 ]
104
-
105
- @property
106
- def reader (self ) -> PyMongoProtocol :
101
+ def get_conn (self ) -> PyMongoProtocol :
107
102
return self .conn [1 ]
108
103
109
104
@@ -124,11 +119,7 @@ def is_closing(self):
124
119
self .conn .is_closing ()
125
120
126
121
@property
127
- def writer (self ):
128
- return self .conn
129
-
130
- @property
131
- def reader (self ):
122
+ def get_conn (self ):
132
123
return self .conn
133
124
134
125
@@ -333,35 +324,8 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
333
324
await asyncio .sleep (_POLL_TIMEOUT )
334
325
335
326
336
- async def async_receive_data (
337
- conn : AsyncConnection ,
338
- deadline : Optional [float ],
339
- request_id : Optional [int ],
340
- max_message_size : int ,
341
- ) -> memoryview :
342
- conn_timeout = conn .conn .gettimeout
343
- timeout : Optional [Union [float , int ]]
344
- if deadline :
345
- # When the timeout has expired perform one final check to
346
- # see if the socket is readable. This helps avoid spurious
347
- # timeouts on AWS Lambda and other FaaS environments.
348
- timeout = max (deadline - time .monotonic (), 0 )
349
- else :
350
- timeout = conn_timeout
351
-
352
- cancellation_task = create_task (_poll_cancellation (conn ))
353
- read_task = create_task (conn .conn .reader .read (request_id , max_message_size ))
354
- tasks = [read_task , cancellation_task ]
355
- done , pending = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
356
- for task in pending :
357
- task .cancel ()
358
- if pending :
359
- await asyncio .wait (pending )
360
- if len (done ) == 0 :
361
- raise socket .timeout ("timed out" )
362
- if read_task in done :
363
- return read_task .result ()
364
- raise _OperationCancelled ("operation cancelled" )
327
+ # Errors raised by sockets (and TLS sockets) when in non-blocking mode.
328
+ BLOCKING_IO_ERRORS = (BlockingIOError , * ssl_support .BLOCKING_IO_ERRORS )
365
329
366
330
367
331
def receive_data (conn : Connection , length : int , deadline : Optional [float ]) -> memoryview :
@@ -384,12 +348,12 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
384
348
short_timeout = _POLL_TIMEOUT
385
349
conn .set_conn_timeout (short_timeout )
386
350
try :
387
- chunk_length = conn .conn .recv_into (mv [bytes_read :])
388
- # except BLOCKING_IO_ERRORS:
389
- # if conn.cancel_context.cancelled:
390
- # raise _OperationCancelled("operation cancelled") from None
391
- # # We reached the true deadline.
392
- # raise socket.timeout("timed out") from None
351
+ chunk_length = conn .conn .get_conn . recv_into (mv [bytes_read :])
352
+ except BLOCKING_IO_ERRORS :
353
+ if conn .cancel_context .cancelled :
354
+ raise _OperationCancelled ("operation cancelled" ) from None
355
+ # We reached the true deadline.
356
+ raise socket .timeout ("timed out" ) from None
393
357
except socket .timeout :
394
358
if conn .cancel_context .cancelled :
395
359
raise _OperationCancelled ("operation cancelled" ) from None
@@ -416,22 +380,42 @@ async def async_receive_message(
416
380
max_message_size : int = MAX_MESSAGE_SIZE ,
417
381
) -> Union [_OpReply , _OpMsg ]:
418
382
"""Receive a raw BSON message or raise socket.error."""
383
+ timeout : Optional [Union [float , int ]]
419
384
if _csot .get_timeout ():
420
385
deadline = _csot .get_deadline ()
421
386
else :
422
- timeout = conn .conn .reader .gettimeout
387
+ timeout = conn .conn .get_conn .gettimeout
423
388
if timeout :
424
389
deadline = time .monotonic () + timeout
425
390
else :
426
391
deadline = None
427
- data , op_code = await async_receive_data (conn , deadline , request_id , max_message_size )
428
- try :
429
- unpack_reply = _UNPACK_REPLY [op_code ]
430
- except KeyError :
431
- raise ProtocolError (
432
- f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
433
- ) from None
434
- return unpack_reply (data )
392
+ if deadline :
393
+ # When the timeout has expired perform one final check to
394
+ # see if the socket is readable. This helps avoid spurious
395
+ # timeouts on AWS Lambda and other FaaS environments.
396
+ timeout = max (deadline - time .monotonic (), 0 )
397
+
398
+ cancellation_task = create_task (_poll_cancellation (conn ))
399
+ read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size ))
400
+ tasks = [read_task , cancellation_task ]
401
+ done , pending = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
402
+ for task in pending :
403
+ task .cancel ()
404
+ if pending :
405
+ await asyncio .wait (pending )
406
+ if len (done ) == 0 :
407
+ raise socket .timeout ("timed out" )
408
+ if read_task in done :
409
+ data , op_code = read_task .result ()
410
+
411
+ try :
412
+ unpack_reply = _UNPACK_REPLY [op_code ]
413
+ except KeyError :
414
+ raise ProtocolError (
415
+ f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
416
+ ) from None
417
+ return unpack_reply (data )
418
+ raise _OperationCancelled ("operation cancelled" )
435
419
436
420
437
421
def receive_message (
0 commit comments