|
| 1 | +# Copyright 2025-present MongoDB, Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Utility and helper methods for creating connections.""" |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +import asyncio |
| 19 | +import functools |
| 20 | +import socket |
| 21 | +import ssl |
| 22 | +from typing import ( |
| 23 | + TYPE_CHECKING, |
| 24 | + Union, |
| 25 | +) |
| 26 | + |
| 27 | +from pymongo import _csot |
| 28 | +from pymongo.asynchronous.helpers import _getaddrinfo |
| 29 | +from pymongo.errors import ( # type:ignore[attr-defined] |
| 30 | + ConnectionFailure, |
| 31 | + _CertificateError, |
| 32 | +) |
| 33 | +from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol |
| 34 | +from pymongo.pool_options import PoolOptions |
| 35 | +from pymongo.pool_shared import ( |
| 36 | + _get_timeout_details, |
| 37 | + _raise_connection_failure, |
| 38 | + _set_keepalive_times, |
| 39 | + _set_non_inheritable_non_atomic, |
| 40 | +) |
| 41 | +from pymongo.ssl_support import HAS_SNI, SSLError |
| 42 | + |
| 43 | +if TYPE_CHECKING: |
| 44 | + from pymongo.pyopenssl_context import _sslConn |
| 45 | + from pymongo.typings import _Address |
| 46 | + |
| 47 | +_IS_SYNC = False |
| 48 | + |
| 49 | + |
| 50 | +async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket: |
| 51 | + """Given (host, port) and PoolOptions, connect and return a raw socket object. |
| 52 | +
|
| 53 | + Can raise socket.error. |
| 54 | +
|
| 55 | + This is a modified version of create_connection from CPython >= 2.7. |
| 56 | + """ |
| 57 | + host, port = address |
| 58 | + |
| 59 | + # Check if dealing with a unix domain socket |
| 60 | + if host.endswith(".sock"): |
| 61 | + if not hasattr(socket, "AF_UNIX"): |
| 62 | + raise ConnectionFailure("UNIX-sockets are not supported on this system") |
| 63 | + sock = socket.socket(socket.AF_UNIX) |
| 64 | + # SOCK_CLOEXEC not supported for Unix sockets. |
| 65 | + _set_non_inheritable_non_atomic(sock.fileno()) |
| 66 | + try: |
| 67 | + sock.connect(host) |
| 68 | + return sock |
| 69 | + except OSError: |
| 70 | + sock.close() |
| 71 | + raise |
| 72 | + |
| 73 | + # Don't try IPv6 if we don't support it. Also skip it if host |
| 74 | + # is 'localhost' (::1 is fine). Avoids slow connect issues |
| 75 | + # like PYTHON-356. |
| 76 | + family = socket.AF_INET |
| 77 | + if socket.has_ipv6 and host != "localhost": |
| 78 | + family = socket.AF_UNSPEC |
| 79 | + |
| 80 | + err = None |
| 81 | + for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] |
| 82 | + af, socktype, proto, dummy, sa = res |
| 83 | + # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited |
| 84 | + # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 |
| 85 | + # all file descriptors are created non-inheritable. See PEP 446. |
| 86 | + try: |
| 87 | + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) |
| 88 | + except OSError: |
| 89 | + # Can SOCK_CLOEXEC be defined even if the kernel doesn't support |
| 90 | + # it? |
| 91 | + sock = socket.socket(af, socktype, proto) |
| 92 | + # Fallback when SOCK_CLOEXEC isn't available. |
| 93 | + _set_non_inheritable_non_atomic(sock.fileno()) |
| 94 | + try: |
| 95 | + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
| 96 | + # CSOT: apply timeout to socket connect. |
| 97 | + timeout = _csot.remaining() |
| 98 | + if timeout is None: |
| 99 | + timeout = options.connect_timeout |
| 100 | + elif timeout <= 0: |
| 101 | + raise socket.timeout("timed out") |
| 102 | + sock.settimeout(timeout) |
| 103 | + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) |
| 104 | + _set_keepalive_times(sock) |
| 105 | + sock.connect(sa) |
| 106 | + return sock |
| 107 | + except OSError as e: |
| 108 | + err = e |
| 109 | + sock.close() |
| 110 | + |
| 111 | + if err is not None: |
| 112 | + raise err |
| 113 | + else: |
| 114 | + # This likely means we tried to connect to an IPv6 only |
| 115 | + # host with an OS/kernel or Python interpreter that doesn't |
| 116 | + # support IPv6. The test case is Jython2.5.1 which doesn't |
| 117 | + # support IPv6 at all. |
| 118 | + raise OSError("getaddrinfo failed") |
| 119 | + |
| 120 | + |
| 121 | +async def _async_configured_socket( |
| 122 | + address: _Address, options: PoolOptions |
| 123 | +) -> Union[socket.socket, _sslConn]: |
| 124 | + """Given (host, port) and PoolOptions, return a raw configured socket. |
| 125 | +
|
| 126 | + Can raise socket.error, ConnectionFailure, or _CertificateError. |
| 127 | +
|
| 128 | + Sets socket's SSL and timeout options. |
| 129 | + """ |
| 130 | + sock = await _async_create_connection(address, options) |
| 131 | + ssl_context = options._ssl_context |
| 132 | + |
| 133 | + if ssl_context is None: |
| 134 | + sock.settimeout(options.socket_timeout) |
| 135 | + return sock |
| 136 | + |
| 137 | + host = address[0] |
| 138 | + try: |
| 139 | + # We have to pass hostname / ip address to wrap_socket |
| 140 | + # to use SSLContext.check_hostname. |
| 141 | + if HAS_SNI: |
| 142 | + if hasattr(ssl_context, "a_wrap_socket"): |
| 143 | + ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc] |
| 144 | + else: |
| 145 | + loop = asyncio.get_running_loop() |
| 146 | + ssl_sock = await loop.run_in_executor( |
| 147 | + None, |
| 148 | + functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] |
| 149 | + ) |
| 150 | + else: |
| 151 | + if hasattr(ssl_context, "a_wrap_socket"): |
| 152 | + ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] |
| 153 | + else: |
| 154 | + loop = asyncio.get_running_loop() |
| 155 | + ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] |
| 156 | + except _CertificateError: |
| 157 | + sock.close() |
| 158 | + # Raise _CertificateError directly like we do after match_hostname |
| 159 | + # below. |
| 160 | + raise |
| 161 | + except (OSError, SSLError) as exc: |
| 162 | + sock.close() |
| 163 | + # We raise AutoReconnect for transient and permanent SSL handshake |
| 164 | + # failures alike. Permanent handshake failures, like protocol |
| 165 | + # mismatch, will be turned into ServerSelectionTimeoutErrors later. |
| 166 | + details = _get_timeout_details(options) |
| 167 | + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) |
| 168 | + if ( |
| 169 | + ssl_context.verify_mode |
| 170 | + and not ssl_context.check_hostname |
| 171 | + and not options.tls_allow_invalid_hostnames |
| 172 | + ): |
| 173 | + try: |
| 174 | + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] |
| 175 | + except _CertificateError: |
| 176 | + ssl_sock.close() |
| 177 | + raise |
| 178 | + |
| 179 | + ssl_sock.settimeout(options.socket_timeout) |
| 180 | + return ssl_sock |
| 181 | + |
| 182 | + |
| 183 | +async def _configured_protocol_interface( |
| 184 | + address: _Address, options: PoolOptions |
| 185 | +) -> AsyncNetworkingInterface: |
| 186 | + """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface. |
| 187 | +
|
| 188 | + Can raise socket.error, ConnectionFailure, or _CertificateError. |
| 189 | +
|
| 190 | + Sets protocol's SSL and timeout options. |
| 191 | + """ |
| 192 | + sock = await _async_create_connection(address, options) |
| 193 | + ssl_context = options._ssl_context |
| 194 | + timeout = options.socket_timeout |
| 195 | + |
| 196 | + if ssl_context is None: |
| 197 | + return AsyncNetworkingInterface( |
| 198 | + await asyncio.get_running_loop().create_connection( |
| 199 | + lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**16), sock=sock |
| 200 | + ) |
| 201 | + ) |
| 202 | + |
| 203 | + host = address[0] |
| 204 | + try: |
| 205 | + # We have to pass hostname / ip address to wrap_socket |
| 206 | + # to use SSLContext.check_hostname. |
| 207 | + transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] |
| 208 | + lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**14), |
| 209 | + sock=sock, |
| 210 | + server_hostname=host, |
| 211 | + ssl=ssl_context, |
| 212 | + ) |
| 213 | + except _CertificateError: |
| 214 | + transport.abort() |
| 215 | + # Raise _CertificateError directly like we do after match_hostname |
| 216 | + # below. |
| 217 | + raise |
| 218 | + except (OSError, SSLError) as exc: |
| 219 | + transport.abort() |
| 220 | + # We raise AutoReconnect for transient and permanent SSL handshake |
| 221 | + # failures alike. Permanent handshake failures, like protocol |
| 222 | + # mismatch, will be turned into ServerSelectionTimeoutErrors later. |
| 223 | + details = _get_timeout_details(options) |
| 224 | + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) |
| 225 | + if ( |
| 226 | + ssl_context.verify_mode |
| 227 | + and not ssl_context.check_hostname |
| 228 | + and not options.tls_allow_invalid_hostnames |
| 229 | + ): |
| 230 | + try: |
| 231 | + ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] |
| 232 | + except _CertificateError: |
| 233 | + transport.abort() |
| 234 | + raise |
| 235 | + |
| 236 | + return AsyncNetworkingInterface((transport, protocol)) |
| 237 | + |
| 238 | + |
| 239 | +if _IS_SYNC: |
| 240 | + from pymongo.synchronous.connection_helpers import _create_connection |
| 241 | + |
| 242 | + def _configured_socket_interface( |
| 243 | + address: _Address, options: PoolOptions |
| 244 | + ) -> NetworkingInterface: |
| 245 | + """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket. |
| 246 | +
|
| 247 | + Can raise socket.error, ConnectionFailure, or _CertificateError. |
| 248 | +
|
| 249 | + Sets socket's SSL and timeout options. |
| 250 | + """ |
| 251 | + sock = _create_connection(address, options) |
| 252 | + ssl_context = options._ssl_context |
| 253 | + |
| 254 | + if ssl_context is None: |
| 255 | + sock.settimeout(options.socket_timeout) |
| 256 | + return NetworkingInterface(sock) |
| 257 | + |
| 258 | + host = address[0] |
| 259 | + try: |
| 260 | + # We have to pass hostname / ip address to wrap_socket |
| 261 | + # to use SSLContext.check_hostname. |
| 262 | + if HAS_SNI: |
| 263 | + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) |
| 264 | + else: |
| 265 | + ssl_sock = ssl_context.wrap_socket(sock) |
| 266 | + except _CertificateError: |
| 267 | + sock.close() |
| 268 | + # Raise _CertificateError directly like we do after match_hostname |
| 269 | + # below. |
| 270 | + raise |
| 271 | + except (OSError, SSLError) as exc: |
| 272 | + sock.close() |
| 273 | + # We raise AutoReconnect for transient and permanent SSL handshake |
| 274 | + # failures alike. Permanent handshake failures, like protocol |
| 275 | + # mismatch, will be turned into ServerSelectionTimeoutErrors later. |
| 276 | + details = _get_timeout_details(options) |
| 277 | + _raise_connection_failure( |
| 278 | + address, exc, "SSL handshake failed: ", timeout_details=details |
| 279 | + ) |
| 280 | + if ( |
| 281 | + ssl_context.verify_mode |
| 282 | + and not ssl_context.check_hostname |
| 283 | + and not options.tls_allow_invalid_hostnames |
| 284 | + ): |
| 285 | + try: |
| 286 | + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore] |
| 287 | + except _CertificateError: |
| 288 | + ssl_sock.close() |
| 289 | + raise |
| 290 | + |
| 291 | + ssl_sock.settimeout(options.socket_timeout) |
| 292 | + return NetworkingInterface(ssl_sock) |
0 commit comments