Skip to content

Commit 682f15b

Browse files
authored
PYTHON-4618 - Fix TypeError: Socket cannot be of type SSLSocket (mongodb#1772)
1 parent 13cf110 commit 682f15b

File tree

4 files changed

+98
-11
lines changed

4 files changed

+98
-11
lines changed

pymongo/network_layer.py

Lines changed: 90 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,111 @@
1818
import asyncio
1919
import socket
2020
import struct
21+
import sys
22+
from asyncio import AbstractEventLoop, Future
2123
from typing import (
22-
TYPE_CHECKING,
2324
Union,
2425
)
2526

2627
from pymongo import ssl_support
2728

28-
if TYPE_CHECKING:
29-
from pymongo.pyopenssl_context import _sslConn
29+
try:
30+
from ssl import SSLError, SSLSocket
31+
32+
_HAVE_SSL = True
33+
except ImportError:
34+
_HAVE_SSL = False
35+
36+
try:
37+
from pymongo.pyopenssl_context import (
38+
BLOCKING_IO_LOOKUP_ERROR,
39+
BLOCKING_IO_READ_ERROR,
40+
BLOCKING_IO_WRITE_ERROR,
41+
_sslConn,
42+
)
43+
44+
_HAVE_PYOPENSSL = True
45+
except ImportError:
46+
_HAVE_PYOPENSSL = False
47+
_sslConn = SSLSocket # type: ignore
48+
from pymongo.ssl_support import ( # type: ignore[assignment]
49+
BLOCKING_IO_LOOKUP_ERROR,
50+
BLOCKING_IO_READ_ERROR,
51+
BLOCKING_IO_WRITE_ERROR,
52+
)
3053

3154
_UNPACK_HEADER = struct.Struct("<iiii").unpack
3255
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
3356
_POLL_TIMEOUT = 0.5
3457
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
35-
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
58+
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
3659

3760

38-
async def async_sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
39-
timeout = socket.gettimeout()
40-
socket.settimeout(0.0)
61+
async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
62+
timeout = sock.gettimeout()
63+
sock.settimeout(0.0)
4164
loop = asyncio.get_event_loop()
4265
try:
43-
await asyncio.wait_for(loop.sock_sendall(socket, buf), timeout=timeout) # type: ignore[arg-type]
66+
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
67+
if sys.platform == "win32":
68+
await asyncio.wait_for(_async_sendall_ssl_windows(sock, buf), timeout=timeout)
69+
else:
70+
await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout)
71+
else:
72+
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
4473
finally:
45-
socket.settimeout(timeout)
74+
sock.settimeout(timeout)
75+
76+
77+
async def _async_sendall_ssl(
78+
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
79+
) -> None:
80+
fd = sock.fileno()
81+
sent = 0
82+
83+
def _is_ready(fut: Future) -> None:
84+
loop.remove_writer(fd)
85+
loop.remove_reader(fd)
86+
if fut.done():
87+
return
88+
fut.set_result(None)
89+
90+
while sent < len(buf):
91+
try:
92+
sent += sock.send(buf)
93+
except BLOCKING_IO_ERRORS as exc:
94+
fd = sock.fileno()
95+
# Check for closed socket.
96+
if fd == -1:
97+
raise SSLError("Underlying socket has been closed") from None
98+
if isinstance(exc, BLOCKING_IO_READ_ERROR):
99+
fut = loop.create_future()
100+
loop.add_reader(fd, _is_ready, fut)
101+
await fut
102+
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
103+
fut = loop.create_future()
104+
loop.add_writer(fd, _is_ready, fut)
105+
await fut
106+
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
107+
fut = loop.create_future()
108+
loop.add_reader(fd, _is_ready, fut)
109+
loop.add_writer(fd, _is_ready, fut)
110+
await fut
111+
112+
113+
# The default Windows asyncio event loop does not support loop.add_reader/add_writer: https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
114+
async def _async_sendall_ssl_windows(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
115+
view = memoryview(buf)
116+
total_length = len(buf)
117+
total_sent = 0
118+
while total_sent < total_length:
119+
try:
120+
sent = sock.send(view[total_sent:])
121+
except BLOCKING_IO_ERRORS:
122+
await asyncio.sleep(0.5)
123+
sent = 0
124+
total_sent += sent
46125

47126

48-
def sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
49-
socket.sendall(buf)
127+
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
128+
sock.sendall(buf)

pymongo/pyopenssl_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def _is_ip_address(address: Any) -> bool:
9090
# According to the docs for socket.send it can raise
9191
# WantX509LookupError and should be retried.
9292
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
93+
BLOCKING_IO_READ_ERROR = _SSL.WantReadError
94+
BLOCKING_IO_WRITE_ERROR = _SSL.WantWriteError
95+
BLOCKING_IO_LOOKUP_ERROR = _SSL.WantX509LookupError
9396

9497

9598
def _ragged_eof(exc: BaseException) -> bool:

pymongo/ssl_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
# Errors raised by SSL sockets when in non-blocking mode.
3232
BLOCKING_IO_ERRORS = (_ssl.SSLWantReadError, _ssl.SSLWantWriteError)
33+
BLOCKING_IO_READ_ERROR = _ssl.SSLWantReadError
34+
BLOCKING_IO_WRITE_ERROR = _ssl.SSLWantWriteError
3335

3436
# Base Exception class
3537
SSLError = _ssl.SSLError

pymongo/ssl_support.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
IPADDR_SAFE = True
5454
SSLError = _ssl.SSLError
5555
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
56+
BLOCKING_IO_READ_ERROR = _ssl.BLOCKING_IO_READ_ERROR
57+
BLOCKING_IO_WRITE_ERROR = _ssl.BLOCKING_IO_WRITE_ERROR
58+
BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR
5659

5760
def get_ssl_context(
5861
certfile: Optional[str],

0 commit comments

Comments
 (0)