Skip to content

Commit 9aa51c5

Browse files
committed
Fix connection helpers
1 parent 850ff33 commit 9aa51c5

File tree

8 files changed

+619
-337
lines changed

8 files changed

+619
-337
lines changed
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# Copyright 2025-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility and helper methods for creating connections."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import functools
20+
import socket
21+
import ssl
22+
from typing import (
23+
TYPE_CHECKING,
24+
Union,
25+
)
26+
27+
from pymongo import _csot
28+
from pymongo.asynchronous.helpers import _getaddrinfo
29+
from pymongo.errors import ( # type:ignore[attr-defined]
30+
ConnectionFailure,
31+
_CertificateError,
32+
)
33+
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
34+
from pymongo.pool_options import PoolOptions
35+
from pymongo.pool_shared import (
36+
_get_timeout_details,
37+
_raise_connection_failure,
38+
_set_keepalive_times,
39+
_set_non_inheritable_non_atomic,
40+
)
41+
from pymongo.ssl_support import HAS_SNI, SSLError
42+
43+
if TYPE_CHECKING:
44+
from pymongo.pyopenssl_context import _sslConn
45+
from pymongo.typings import _Address
46+
47+
_IS_SYNC = False
48+
49+
50+
async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket:
51+
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
52+
53+
Can raise socket.error.
54+
55+
This is a modified version of create_connection from CPython >= 2.7.
56+
"""
57+
host, port = address
58+
59+
# Check if dealing with a unix domain socket
60+
if host.endswith(".sock"):
61+
if not hasattr(socket, "AF_UNIX"):
62+
raise ConnectionFailure("UNIX-sockets are not supported on this system")
63+
sock = socket.socket(socket.AF_UNIX)
64+
# SOCK_CLOEXEC not supported for Unix sockets.
65+
_set_non_inheritable_non_atomic(sock.fileno())
66+
try:
67+
sock.connect(host)
68+
return sock
69+
except OSError:
70+
sock.close()
71+
raise
72+
73+
# Don't try IPv6 if we don't support it. Also skip it if host
74+
# is 'localhost' (::1 is fine). Avoids slow connect issues
75+
# like PYTHON-356.
76+
family = socket.AF_INET
77+
if socket.has_ipv6 and host != "localhost":
78+
family = socket.AF_UNSPEC
79+
80+
err = None
81+
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
82+
af, socktype, proto, dummy, sa = res
83+
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
84+
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
85+
# all file descriptors are created non-inheritable. See PEP 446.
86+
try:
87+
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
88+
except OSError:
89+
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
90+
# it?
91+
sock = socket.socket(af, socktype, proto)
92+
# Fallback when SOCK_CLOEXEC isn't available.
93+
_set_non_inheritable_non_atomic(sock.fileno())
94+
try:
95+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
96+
# CSOT: apply timeout to socket connect.
97+
timeout = _csot.remaining()
98+
if timeout is None:
99+
timeout = options.connect_timeout
100+
elif timeout <= 0:
101+
raise socket.timeout("timed out")
102+
sock.settimeout(timeout)
103+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
104+
_set_keepalive_times(sock)
105+
sock.connect(sa)
106+
return sock
107+
except OSError as e:
108+
err = e
109+
sock.close()
110+
111+
if err is not None:
112+
raise err
113+
else:
114+
# This likely means we tried to connect to an IPv6 only
115+
# host with an OS/kernel or Python interpreter that doesn't
116+
# support IPv6. The test case is Jython2.5.1 which doesn't
117+
# support IPv6 at all.
118+
raise OSError("getaddrinfo failed")
119+
120+
121+
async def _async_configured_socket(
122+
address: _Address, options: PoolOptions
123+
) -> Union[socket.socket, _sslConn]:
124+
"""Given (host, port) and PoolOptions, return a raw configured socket.
125+
126+
Can raise socket.error, ConnectionFailure, or _CertificateError.
127+
128+
Sets socket's SSL and timeout options.
129+
"""
130+
sock = await _async_create_connection(address, options)
131+
ssl_context = options._ssl_context
132+
133+
if ssl_context is None:
134+
sock.settimeout(options.socket_timeout)
135+
return sock
136+
137+
host = address[0]
138+
try:
139+
# We have to pass hostname / ip address to wrap_socket
140+
# to use SSLContext.check_hostname.
141+
if HAS_SNI:
142+
if hasattr(ssl_context, "a_wrap_socket"):
143+
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
144+
else:
145+
loop = asyncio.get_running_loop()
146+
ssl_sock = await loop.run_in_executor(
147+
None,
148+
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
149+
)
150+
else:
151+
if hasattr(ssl_context, "a_wrap_socket"):
152+
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
153+
else:
154+
loop = asyncio.get_running_loop()
155+
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
156+
except _CertificateError:
157+
sock.close()
158+
# Raise _CertificateError directly like we do after match_hostname
159+
# below.
160+
raise
161+
except (OSError, SSLError) as exc:
162+
sock.close()
163+
# We raise AutoReconnect for transient and permanent SSL handshake
164+
# failures alike. Permanent handshake failures, like protocol
165+
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
166+
details = _get_timeout_details(options)
167+
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
168+
if (
169+
ssl_context.verify_mode
170+
and not ssl_context.check_hostname
171+
and not options.tls_allow_invalid_hostnames
172+
):
173+
try:
174+
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
175+
except _CertificateError:
176+
ssl_sock.close()
177+
raise
178+
179+
ssl_sock.settimeout(options.socket_timeout)
180+
return ssl_sock
181+
182+
183+
async def _configured_protocol_interface(
184+
address: _Address, options: PoolOptions
185+
) -> AsyncNetworkingInterface:
186+
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
187+
188+
Can raise socket.error, ConnectionFailure, or _CertificateError.
189+
190+
Sets protocol's SSL and timeout options.
191+
"""
192+
sock = await _async_create_connection(address, options)
193+
ssl_context = options._ssl_context
194+
timeout = options.socket_timeout
195+
196+
if ssl_context is None:
197+
return AsyncNetworkingInterface(
198+
await asyncio.get_running_loop().create_connection(
199+
lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**16), sock=sock
200+
)
201+
)
202+
203+
host = address[0]
204+
try:
205+
# We have to pass hostname / ip address to wrap_socket
206+
# to use SSLContext.check_hostname.
207+
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
208+
lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**14),
209+
sock=sock,
210+
server_hostname=host,
211+
ssl=ssl_context,
212+
)
213+
except _CertificateError:
214+
transport.abort()
215+
# Raise _CertificateError directly like we do after match_hostname
216+
# below.
217+
raise
218+
except (OSError, SSLError) as exc:
219+
transport.abort()
220+
# We raise AutoReconnect for transient and permanent SSL handshake
221+
# failures alike. Permanent handshake failures, like protocol
222+
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
223+
details = _get_timeout_details(options)
224+
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
225+
if (
226+
ssl_context.verify_mode
227+
and not ssl_context.check_hostname
228+
and not options.tls_allow_invalid_hostnames
229+
):
230+
try:
231+
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
232+
except _CertificateError:
233+
transport.abort()
234+
raise
235+
236+
return AsyncNetworkingInterface((transport, protocol))
237+
238+
239+
if _IS_SYNC:
240+
from pymongo.synchronous.connection_helpers import _create_connection
241+
242+
def _configured_socket_interface(
243+
address: _Address, options: PoolOptions
244+
) -> NetworkingInterface:
245+
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
246+
247+
Can raise socket.error, ConnectionFailure, or _CertificateError.
248+
249+
Sets socket's SSL and timeout options.
250+
"""
251+
sock = _create_connection(address, options)
252+
ssl_context = options._ssl_context
253+
254+
if ssl_context is None:
255+
sock.settimeout(options.socket_timeout)
256+
return NetworkingInterface(sock)
257+
258+
host = address[0]
259+
try:
260+
# We have to pass hostname / ip address to wrap_socket
261+
# to use SSLContext.check_hostname.
262+
if HAS_SNI:
263+
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
264+
else:
265+
ssl_sock = ssl_context.wrap_socket(sock)
266+
except _CertificateError:
267+
sock.close()
268+
# Raise _CertificateError directly like we do after match_hostname
269+
# below.
270+
raise
271+
except (OSError, SSLError) as exc:
272+
sock.close()
273+
# We raise AutoReconnect for transient and permanent SSL handshake
274+
# failures alike. Permanent handshake failures, like protocol
275+
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
276+
details = _get_timeout_details(options)
277+
_raise_connection_failure(
278+
address, exc, "SSL handshake failed: ", timeout_details=details
279+
)
280+
if (
281+
ssl_context.verify_mode
282+
and not ssl_context.check_hostname
283+
and not options.tls_allow_invalid_hostnames
284+
):
285+
try:
286+
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore]
287+
except _CertificateError:
288+
ssl_sock.close()
289+
raise
290+
291+
ssl_sock.settimeout(options.socket_timeout)
292+
return NetworkingInterface(ssl_sock)

pymongo/asynchronous/encryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
cast,
3939
)
4040

41-
from pymongo.pool_shared import _async_configured_socket
41+
from pymongo.asynchronous.connection_helpers import _async_configured_socket
4242

4343
try:
4444
from pymongocrypt.asynchronous.auto_encrypter import AsyncAutoEncrypter # type:ignore[import]

pymongo/asynchronous/pool.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from bson import DEFAULT_CODEC_OPTIONS
3737
from pymongo import _csot, helpers_shared
3838
from pymongo.asynchronous.client_session import _validate_session_write_concern
39+
from pymongo.asynchronous.connection_helpers import _configured_protocol_interface
3940
from pymongo.asynchronous.helpers import _handle_reauth
4041
from pymongo.asynchronous.network import command
4142
from pymongo.common import (
@@ -76,7 +77,6 @@
7677
from pymongo.pool_options import PoolOptions
7778
from pymongo.pool_shared import (
7879
_CancellationContext,
79-
_configured_protocol,
8080
_get_timeout_details,
8181
_raise_connection_failure,
8282
format_timeout_details,
@@ -1008,7 +1008,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
10081008
)
10091009

10101010
try:
1011-
networking_interface = await _configured_protocol(self.address, self.opts)
1011+
networking_interface = await _configured_protocol_interface(self.address, self.opts)
10121012
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
10131013
except BaseException as error:
10141014
async with self.lock:
@@ -1439,9 +1439,10 @@ def _raise_wait_queue_timeout(self, checkout_started_time: float) -> NoReturn:
14391439
f"maxPoolSize: {self.opts.max_pool_size}, timeout: {timeout}"
14401440
)
14411441

1442-
# def __del__(self) -> None:
1443-
# # Avoid ResourceWarnings in Python 3
1444-
# # Close all sockets without calling reset() or close() because it is
1445-
# # not safe to acquire a lock in __del__.
1446-
# for conn in self.conns:
1447-
# conn.close_conn(None)
1442+
def __del__(self) -> None:
1443+
# Avoid ResourceWarnings in Python 3
1444+
# Close all sockets without calling reset() or close() because it is
1445+
# not safe to acquire a lock in __del__.
1446+
if _IS_SYNC:
1447+
for conn in self.conns:
1448+
conn.close_conn(None)

0 commit comments

Comments
 (0)