2626 Union ,
2727)
2828
29- from pymongo import _csot
29+ from pymongo import _csot , ssl_support
3030from pymongo ._asyncio_task import create_task
3131from pymongo .common import MAX_MESSAGE_SIZE
3232from pymongo .compression_support import decompress
5959
6060
6161class NetworkingInterfaceBase :
62- def __init__ (self , conn : Union [socket .socket , _sslConn ] | tuple [asyncio .BaseTransport , PyMongoProtocol ]):
62+ def __init__ (
63+ self , conn : Union [socket .socket , _sslConn ] | tuple [asyncio .BaseTransport , PyMongoProtocol ]
64+ ):
6365 self .conn = conn
6466
6567 def gettimeout (self ):
@@ -74,10 +76,7 @@ def close(self):
7476 def is_closing (self ) -> bool :
7577 raise NotImplementedError
7678
77- def writer (self ):
78- raise NotImplementedError
79-
80- def reader (self ):
79+ def get_conn (self ):
8180 raise NotImplementedError
8281
8382
@@ -99,11 +98,7 @@ def is_closing(self):
9998 self .conn [0 ].is_closing ()
10099
101100 @property
102- def writer (self ) -> PyMongoProtocol :
103- return self .conn [1 ]
104-
105- @property
106- def reader (self ) -> PyMongoProtocol :
101+ def get_conn (self ) -> PyMongoProtocol :
107102 return self .conn [1 ]
108103
109104
@@ -124,11 +119,7 @@ def is_closing(self):
124119 self .conn .is_closing ()
125120
126121 @property
127- def writer (self ):
128- return self .conn
129-
130- @property
131- def reader (self ):
122+ def get_conn (self ):
132123 return self .conn
133124
134125
@@ -333,35 +324,8 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
333324 await asyncio .sleep (_POLL_TIMEOUT )
334325
335326
336- async def async_receive_data (
337- conn : AsyncConnection ,
338- deadline : Optional [float ],
339- request_id : Optional [int ],
340- max_message_size : int ,
341- ) -> memoryview :
342- conn_timeout = conn .conn .gettimeout
343- timeout : Optional [Union [float , int ]]
344- if deadline :
345- # When the timeout has expired perform one final check to
346- # see if the socket is readable. This helps avoid spurious
347- # timeouts on AWS Lambda and other FaaS environments.
348- timeout = max (deadline - time .monotonic (), 0 )
349- else :
350- timeout = conn_timeout
351-
352- cancellation_task = create_task (_poll_cancellation (conn ))
353- read_task = create_task (conn .conn .reader .read (request_id , max_message_size ))
354- tasks = [read_task , cancellation_task ]
355- done , pending = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
356- for task in pending :
357- task .cancel ()
358- if pending :
359- await asyncio .wait (pending )
360- if len (done ) == 0 :
361- raise socket .timeout ("timed out" )
362- if read_task in done :
363- return read_task .result ()
364- raise _OperationCancelled ("operation cancelled" )
327+ # Errors raised by sockets (and TLS sockets) when in non-blocking mode.
328+ BLOCKING_IO_ERRORS = (BlockingIOError , * ssl_support .BLOCKING_IO_ERRORS )
365329
366330
367331def receive_data (conn : Connection , length : int , deadline : Optional [float ]) -> memoryview :
@@ -384,12 +348,12 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
384348 short_timeout = _POLL_TIMEOUT
385349 conn .set_conn_timeout (short_timeout )
386350 try :
387- chunk_length = conn .conn .recv_into (mv [bytes_read :])
388- # except BLOCKING_IO_ERRORS:
389- # if conn.cancel_context.cancelled:
390- # raise _OperationCancelled("operation cancelled") from None
391- # # We reached the true deadline.
392- # raise socket.timeout("timed out") from None
351+ chunk_length = conn .conn .get_conn . recv_into (mv [bytes_read :])
352+ except BLOCKING_IO_ERRORS :
353+ if conn .cancel_context .cancelled :
354+ raise _OperationCancelled ("operation cancelled" ) from None
355+ # We reached the true deadline.
356+ raise socket .timeout ("timed out" ) from None
393357 except socket .timeout :
394358 if conn .cancel_context .cancelled :
395359 raise _OperationCancelled ("operation cancelled" ) from None
@@ -416,22 +380,42 @@ async def async_receive_message(
416380 max_message_size : int = MAX_MESSAGE_SIZE ,
417381) -> Union [_OpReply , _OpMsg ]:
418382 """Receive a raw BSON message or raise socket.error."""
383+ timeout : Optional [Union [float , int ]]
419384 if _csot .get_timeout ():
420385 deadline = _csot .get_deadline ()
421386 else :
422- timeout = conn .conn .reader .gettimeout
387+ timeout = conn .conn .get_conn .gettimeout
423388 if timeout :
424389 deadline = time .monotonic () + timeout
425390 else :
426391 deadline = None
427- data , op_code = await async_receive_data (conn , deadline , request_id , max_message_size )
428- try :
429- unpack_reply = _UNPACK_REPLY [op_code ]
430- except KeyError :
431- raise ProtocolError (
432- f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
433- ) from None
434- return unpack_reply (data )
392+ if deadline :
393+ # When the timeout has expired perform one final check to
394+ # see if the socket is readable. This helps avoid spurious
395+ # timeouts on AWS Lambda and other FaaS environments.
396+ timeout = max (deadline - time .monotonic (), 0 )
397+
398+ cancellation_task = create_task (_poll_cancellation (conn ))
399+ read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size ))
400+ tasks = [read_task , cancellation_task ]
401+ done , pending = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
402+ for task in pending :
403+ task .cancel ()
404+ if pending :
405+ await asyncio .wait (pending )
406+ if len (done ) == 0 :
407+ raise socket .timeout ("timed out" )
408+ if read_task in done :
409+ data , op_code = read_task .result ()
410+
411+ try :
412+ unpack_reply = _UNPACK_REPLY [op_code ]
413+ except KeyError :
414+ raise ProtocolError (
415+ f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
416+ ) from None
417+ return unpack_reply (data )
418+ raise _OperationCancelled ("operation cancelled" )
435419
436420
437421def receive_message (
0 commit comments