diff --git a/cheroot/ssl/__init__.py b/cheroot/ssl/__init__.py index db6e8c24c3..c0072c0557 100644 --- a/cheroot/ssl/__init__.py +++ b/cheroot/ssl/__init__.py @@ -20,12 +20,15 @@ def __init__( private_key, certificate_chain=None, ciphers=None, + *, + private_key_password=None, ): - """Set up certificates, private key ciphers and reset context.""" + """Set up certificates, private key, ciphers and reset context.""" self.certificate = certificate self.private_key = private_key self.certificate_chain = certificate_chain self.ciphers = ciphers + self.private_key_password = private_key_password self.context = None @abstractmethod diff --git a/cheroot/ssl/__init__.pyi b/cheroot/ssl/__init__.pyi index 19a0a2e4da..c595121546 100644 --- a/cheroot/ssl/__init__.pyi +++ b/cheroot/ssl/__init__.pyi @@ -6,6 +6,7 @@ class Adapter(ABC): private_key: Any certificate_chain: Any ciphers: Any + private_key_password: str | bytes | None context: Any @abstractmethod def __init__( @@ -14,6 +15,8 @@ class Adapter(ABC): private_key, certificate_chain: Any | None = ..., ciphers: Any | None = ..., + *, + private_key_password: str | bytes | None = ..., ): ... @abstractmethod def bind(self, sock): ... diff --git a/cheroot/ssl/builtin.py b/cheroot/ssl/builtin.py index 1cec20255d..ed747ab6e3 100644 --- a/cheroot/ssl/builtin.py +++ b/cheroot/ssl/builtin.py @@ -76,10 +76,20 @@ def _loopback_for_cert_thread(context, server): ssl_sock.send(b'0000') -def _loopback_for_cert(certificate, private_key, certificate_chain): +def _loopback_for_cert( + certificate, + private_key, + certificate_chain, + *, + private_key_password=None, +): """Create a loopback connection to parse a cert with a private key.""" context = ssl.create_default_context(cafile=certificate_chain) - context.load_cert_chain(certificate, private_key) + context.load_cert_chain( + certificate, + private_key, + password=private_key_password, + ) context.check_hostname = False context.verify_mode = ssl.CERT_NONE @@ -112,7 +122,13 @@ def _loopback_for_cert(certificate, private_key, certificate_chain): server.close() -def _parse_cert(certificate, private_key, certificate_chain): +def _parse_cert( + certificate, + private_key, + certificate_chain, + *, + private_key_password=None, +): """Parse a certificate.""" # loopback_for_cert uses socket.socketpair which was only # introduced in Python 3.0 for *nix and 3.5 for Windows @@ -120,7 +136,12 @@ def _parse_cert(certificate, private_key, certificate_chain): # it also requires a private key either in its own file # or combined with the cert (SSLError) with suppress(AttributeError, ssl.SSLError, OSError): - return _loopback_for_cert(certificate, private_key, certificate_chain) + return _loopback_for_cert( + certificate, + private_key, + certificate_chain, + private_key_password=private_key_password, + ) # KLUDGE: using an undocumented, private, test method to parse a cert # unfortunately, it is the only built-in way without a connection @@ -153,6 +174,9 @@ class BuiltinSSLAdapter(Adapter): ciphers = None """The ciphers list of SSL.""" + private_key_password = None + """Optional passphrase for password protected private key.""" + # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert CERT_KEY_TO_ENV = { 'version': 'M_VERSION', @@ -208,6 +232,8 @@ def __init__( private_key, certificate_chain=None, ciphers=None, + *, + private_key_password=None, ): """Set up context in addition to base class properties if available.""" if ssl is None: @@ -218,19 +244,29 @@ def __init__( private_key, certificate_chain, ciphers, + private_key_password=private_key_password, ) self.context = ssl.create_default_context( purpose=ssl.Purpose.CLIENT_AUTH, cafile=certificate_chain, ) - self.context.load_cert_chain(certificate, private_key) + self.context.load_cert_chain( + certificate, + private_key, + password=private_key_password, + ) if self.ciphers is not None: self.context.set_ciphers(ciphers) self._server_env = self._make_env_cert_dict( 'SSL_SERVER', - _parse_cert(certificate, private_key, self.certificate_chain), + _parse_cert( + certificate, + private_key, + self.certificate_chain, + private_key_password=private_key_password, + ), ) if not self._server_env: return diff --git a/cheroot/ssl/builtin.pyi b/cheroot/ssl/builtin.pyi index ca106aca12..b05aaf5ad7 100644 --- a/cheroot/ssl/builtin.pyi +++ b/cheroot/ssl/builtin.pyi @@ -13,6 +13,8 @@ class BuiltinSSLAdapter(Adapter): private_key, certificate_chain: Any | None = ..., ciphers: Any | None = ..., + *, + private_key_password: str | bytes | None = ..., ) -> None: ... @property def context(self): ... diff --git a/cheroot/ssl/pyopenssl.py b/cheroot/ssl/pyopenssl.py index d17db00468..e5aedb585c 100644 --- a/cheroot/ssl/pyopenssl.py +++ b/cheroot/ssl/pyopenssl.py @@ -54,6 +54,7 @@ import sys import threading import time +from warnings import warn as _warn try: @@ -293,12 +294,17 @@ class pyOpenSSLAdapter(Adapter): ciphers = None """The ciphers list of TLS.""" + private_key_password = None + """Optional passphrase for password protected private key.""" + def __init__( self, certificate, private_key, certificate_chain=None, ciphers=None, + *, + private_key_password=None, ): """Initialize OpenSSL Adapter instance.""" if SSL is None: @@ -309,6 +315,7 @@ def __init__( private_key, certificate_chain, ciphers, + private_key_password=private_key_password, ) self._environ = None @@ -328,6 +335,31 @@ def wrap(self, sock): # closing so we can't reliably access protocol/client cert for the env return sock, self._environ.copy() + def _password_callback( + self, + password_max_length, + _verify_twice, + password, + /, + ): + """Pass a passphrase to password protected private key.""" + b_password = b'' + if isinstance(password, str): + b_password = password.encode('utf-8') + elif isinstance(password, bytes): + b_password = password + + password_length = len(b_password) + if password_length > password_max_length: + _warn( + f'User-provided password is {password_length} bytes long and will ' + f'be truncated since it exceeds the maximum of {password_max_length}.', + UserWarning, + stacklevel=1, + ) + + return b_password + def get_context(self): """Return an ``SSL.Context`` from self attributes. @@ -335,6 +367,7 @@ def get_context(self): """ # See https://code.activestate.com/recipes/442473/ c = SSL.Context(SSL.SSLv23_METHOD) + c.set_passwd_cb(self._password_callback, self.private_key_password) c.use_privatekey_file(self.private_key) if self.certificate_chain: c.load_verify_locations(self.certificate_chain) diff --git a/cheroot/ssl/pyopenssl.pyi b/cheroot/ssl/pyopenssl.pyi index 1409fffcb7..59dae05ae7 100644 --- a/cheroot/ssl/pyopenssl.pyi +++ b/cheroot/ssl/pyopenssl.pyi @@ -31,9 +31,18 @@ class pyOpenSSLAdapter(Adapter): private_key, certificate_chain: Any | None = ..., ciphers: Any | None = ..., + *, + private_key_password: str | bytes | None = ..., ) -> None: ... def bind(self, sock): ... def wrap(self, sock): ... + def _password_callback( + self, + password_max_length: int, + _verify_twice: bool, + password: bytes | str | None, + /, + ) -> bytes: ... def get_environ(self): ... def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... def get_context(self) -> SSL.Context: ... diff --git a/cheroot/test/test_ssl.py b/cheroot/test/test_ssl.py index 7fbbe2b229..e97460e529 100644 --- a/cheroot/test/test_ssl.py +++ b/cheroot/test/test_ssl.py @@ -1,5 +1,6 @@ """Tests for TLS support.""" +import contextlib import functools import http.client import json @@ -16,6 +17,13 @@ import OpenSSL.SSL import requests import trustme +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import ( + BestAvailableEncryption, + Encoding, + PrivateFormat, + load_pem_private_key, +) from .._compat import ( IS_ABOVE_OPENSSL10, @@ -31,7 +39,7 @@ ntob, ntou, ) -from ..server import HTTPServer, get_ssl_adapter_class +from ..server import Gateway, HTTPServer, get_ssl_adapter_class from ..testing import ( ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV6, @@ -138,6 +146,12 @@ def make_tls_http_server(bind_addr, ssl_adapter, request): return httpserver +@pytest.fixture(scope='session') +def private_key_password(): + """Provide hardcoded password for private key.""" + return 'криївка' + + @pytest.fixture def tls_http_server(request): """Provision a server creator as a fixture.""" @@ -178,6 +192,34 @@ def tls_certificate_private_key_pem_path(tls_certificate): yield cert_key_pem +@pytest.fixture +def tls_certificate_passwd_private_key_pem_path( + tls_certificate, + private_key_password, + tmp_path, +): + """Provide a certificate private key PEM file path via fixture.""" + key_as_bytes = tls_certificate.private_key_pem.bytes() + private_key_object = load_pem_private_key( + key_as_bytes, + password=None, + backend=default_backend(), + ) + + encrypted_key_as_bytes = private_key_object.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.PKCS8, + encryption_algorithm=BestAvailableEncryption( + password=private_key_password.encode('utf-8'), + ), + ) + + key_file = tmp_path / 'key.pem' + key_file.write_bytes(encrypted_key_as_bytes) + + return key_file + + def _thread_except_hook(exceptions, args): """Append uncaught exception ``args`` in threads to ``exceptions``.""" if issubclass(args.exc_type, SystemExit): @@ -757,3 +799,142 @@ def test_http_over_https_error( 'The underlying error is {underlying_error!r}'.format(**locals()) ) assert expected_error_text in err_text + + +@pytest.mark.parametrize( + ('adapter_type', 'encrypted_key', 'password_as_bytes'), + ( + ('builtin', True, True), + ('builtin', True, False), + ('builtin', False, True), + ('builtin', False, False), + ('pyopenssl', True, True), + ('pyopenssl', True, False), + ('pyopenssl', False, True), + ('pyopenssl', False, False), + ), + ids=( + 'builtin-encrypted-key-with-str-password', + 'builtin-encrypted-key-with-bytes-password', + 'builtin-non-encrypted-key-with-str-password', + 'builtin-non-encrypted-key-with-bytes-password', + 'pyopenssl-encrypted-key-with-str-password', + 'pyopenssl-encrypted-key-with-bytes-password', + 'pyopenssl-non-encrypted-key-with-str-password', + 'pyopenssl-non-encrypted-key-with-bytes-password', + ), +) +# pylint: disable-next=too-many-positional-arguments +def test_ssl_adapters_with_private_key_password( + private_key_password, + tls_certificate_chain_pem_path, + tls_certificate_passwd_private_key_pem_path, + tls_certificate_private_key_pem_path, + adapter_type, + encrypted_key, + password_as_bytes, +): + """Check that server starts using ssl adapter with password-protected private key.""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + key_file = ( + tls_certificate_passwd_private_key_pem_path + if encrypted_key + else tls_certificate_private_key_pem_path + ) + key_pass = ( + private_key_password.encode('utf-8') + if password_as_bytes + else private_key_password + ) + + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + certificate=tls_certificate_chain_pem_path, + private_key=key_file, + private_key_password=key_pass, + ) + + httpserver.ssl_adapter = tls_adapter + httpserver.prepare() + + assert httpserver.ready + assert httpserver.requests._threads + for thr in httpserver.requests._threads: + assert thr.ready + + httpserver.stop() + + +@pytest.mark.parametrize( + 'adapter_type', + ('builtin',), +) +def test_builtin_adapter_with_false_key_password( + tls_certificate_chain_pem_path, + tls_certificate_passwd_private_key_pem_path, + adapter_type, +): + """Check that builtin ssl-adapter initialization fails when wrong private key password given.""" + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + with pytest.raises(ssl.SSLError, match=r'\[SSL\] PEM.+'): + tls_adapter_cls( + certificate=tls_certificate_chain_pem_path, + private_key=tls_certificate_passwd_private_key_pem_path, + private_key_password='x' * 256, + ) + + +@pytest.mark.parametrize( + ('adapter_type', 'false_password', 'expected_warn'), + ( + ( + 'pyopenssl', + '837550fd-bcb9-4320-87e6-09de6456b09', + contextlib.nullcontext(), + ), + ('pyopenssl', 555555, contextlib.nullcontext()), + ( + 'pyopenssl', + '@' * 2048, + pytest.warns( + UserWarning, + match=r'User-provided password is 2048 bytes.+', + ), + ), + ), + ids=('false-password', 'integer-password', 'too-long-password'), +) +def test_openssl_adapter_with_false_key_password( + tls_certificate_chain_pem_path, + tls_certificate_passwd_private_key_pem_path, + adapter_type, + false_password, + expected_warn, +): + """Check that server init fails when wrong private key password given.""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + certificate=tls_certificate_chain_pem_path, + private_key=tls_certificate_passwd_private_key_pem_path, + private_key_password=false_password, + ) + + httpserver.ssl_adapter = tls_adapter + + with expected_warn, pytest.raises( + OpenSSL.SSL.Error, + match=r'.+bad decrypt.+', + ): + httpserver.prepare() + + assert not httpserver.requests._threads + assert not httpserver.ready diff --git a/docs/changelog-fragments.d/752.feature.rst b/docs/changelog-fragments.d/752.feature.rst new file mode 100644 index 0000000000..e7ea2b905e --- /dev/null +++ b/docs/changelog-fragments.d/752.feature.rst @@ -0,0 +1,2 @@ +Added optional private key password argument to SSL adapters to support password-protected private keys +-- by :user:`jatalahd`.