Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cheroot/ssl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ def __init__(
private_key,
certificate_chain=None,
ciphers=None,
*,
private_key_password=None,
):
"""Set up certificates, private key ciphers and reset context."""
"""Set up certificates, private key, ciphers and reset context."""
self.certificate = certificate
self.private_key = private_key
self.certificate_chain = certificate_chain
self.ciphers = ciphers
self.private_key_password = private_key_password
self.context = None

@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions cheroot/ssl/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class Adapter(ABC):
private_key: Any
certificate_chain: Any
ciphers: Any
private_key_password: str | bytes | None
context: Any
@abstractmethod
def __init__(
Expand All @@ -14,6 +15,8 @@ class Adapter(ABC):
private_key,
certificate_chain: Any | None = ...,
ciphers: Any | None = ...,
*,
private_key_password: str | bytes | None = ...,
): ...
@abstractmethod
def bind(self, sock): ...
Expand Down
48 changes: 42 additions & 6 deletions cheroot/ssl/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,20 @@ def _loopback_for_cert_thread(context, server):
ssl_sock.send(b'0000')


def _loopback_for_cert(certificate, private_key, certificate_chain):
def _loopback_for_cert(
certificate,
private_key,
certificate_chain,
*,
private_key_password=None,
):
"""Create a loopback connection to parse a cert with a private key."""
context = ssl.create_default_context(cafile=certificate_chain)
context.load_cert_chain(certificate, private_key)
context.load_cert_chain(
certificate,
private_key,
password=private_key_password,
)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

Expand Down Expand Up @@ -112,15 +122,26 @@ def _loopback_for_cert(certificate, private_key, certificate_chain):
server.close()


def _parse_cert(certificate, private_key, certificate_chain):
def _parse_cert(
certificate,
private_key,
certificate_chain,
*,
private_key_password=None,
):
"""Parse a certificate."""
# loopback_for_cert uses socket.socketpair which was only
# introduced in Python 3.0 for *nix and 3.5 for Windows
# and requires OS support (AttributeError, OSError)
# it also requires a private key either in its own file
# or combined with the cert (SSLError)
with suppress(AttributeError, ssl.SSLError, OSError):
return _loopback_for_cert(certificate, private_key, certificate_chain)
return _loopback_for_cert(
certificate,
private_key,
certificate_chain,
private_key_password=private_key_password,
)

# KLUDGE: using an undocumented, private, test method to parse a cert
# unfortunately, it is the only built-in way without a connection
Expand Down Expand Up @@ -153,6 +174,9 @@ class BuiltinSSLAdapter(Adapter):
ciphers = None
"""The ciphers list of SSL."""

private_key_password = None
"""Optional passphrase for password protected private key."""

# from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert
CERT_KEY_TO_ENV = {
'version': 'M_VERSION',
Expand Down Expand Up @@ -208,6 +232,8 @@ def __init__(
private_key,
certificate_chain=None,
ciphers=None,
*,
private_key_password=None,
):
"""Set up context in addition to base class properties if available."""
if ssl is None:
Expand All @@ -218,19 +244,29 @@ def __init__(
private_key,
certificate_chain,
ciphers,
private_key_password=private_key_password,
)

self.context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH,
cafile=certificate_chain,
)
self.context.load_cert_chain(certificate, private_key)
self.context.load_cert_chain(
certificate,
private_key,
password=private_key_password,
)
if self.ciphers is not None:
self.context.set_ciphers(ciphers)

self._server_env = self._make_env_cert_dict(
'SSL_SERVER',
_parse_cert(certificate, private_key, self.certificate_chain),
_parse_cert(
certificate,
private_key,
self.certificate_chain,
private_key_password=private_key_password,
),
)
if not self._server_env:
return
Expand Down
2 changes: 2 additions & 0 deletions cheroot/ssl/builtin.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class BuiltinSSLAdapter(Adapter):
private_key,
certificate_chain: Any | None = ...,
ciphers: Any | None = ...,
*,
private_key_password: str | bytes | None = ...,
) -> None: ...
@property
def context(self): ...
Expand Down
33 changes: 33 additions & 0 deletions cheroot/ssl/pyopenssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import sys
import threading
import time
from warnings import warn as _warn


try:
Expand Down Expand Up @@ -293,12 +294,17 @@ class pyOpenSSLAdapter(Adapter):
ciphers = None
"""The ciphers list of TLS."""

private_key_password = None
"""Optional passphrase for password protected private key."""

def __init__(
self,
certificate,
private_key,
certificate_chain=None,
ciphers=None,
*,
private_key_password=None,
):
"""Initialize OpenSSL Adapter instance."""
if SSL is None:
Expand All @@ -309,6 +315,7 @@ def __init__(
private_key,
certificate_chain,
ciphers,
private_key_password=private_key_password,
)

self._environ = None
Expand All @@ -328,13 +335,39 @@ def wrap(self, sock):
# closing so we can't reliably access protocol/client cert for the env
return sock, self._environ.copy()

def _password_callback(
self,
password_max_length,
_verify_twice,
password,
/,
):
"""Pass a passphrase to password protected private key."""
b_password = b''
if isinstance(password, str):
b_password = password.encode('utf-8')
elif isinstance(password, bytes):
b_password = password

password_length = len(b_password)
if password_length > password_max_length:
_warn(
f'User-provided password is {password_length} bytes long and will '
f'be truncated since it exceeds the maximum of {password_max_length}.',
UserWarning,
stacklevel=1,
)

return b_password

def get_context(self):
"""Return an ``SSL.Context`` from self attributes.

Ref: :py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`
"""
# See https://code.activestate.com/recipes/442473/
c = SSL.Context(SSL.SSLv23_METHOD)
c.set_passwd_cb(self._password_callback, self.private_key_password)
c.use_privatekey_file(self.private_key)
if self.certificate_chain:
c.load_verify_locations(self.certificate_chain)
Expand Down
9 changes: 9 additions & 0 deletions cheroot/ssl/pyopenssl.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ class pyOpenSSLAdapter(Adapter):
private_key,
certificate_chain: Any | None = ...,
ciphers: Any | None = ...,
*,
private_key_password: str | bytes | None = ...,
) -> None: ...
def bind(self, sock): ...
def wrap(self, sock): ...
def _password_callback(
self,
password_max_length: int,
_verify_twice: bool,
password: bytes | str | None,
/,
) -> bytes: ...
def get_environ(self): ...
def makefile(self, sock, mode: str = ..., bufsize: int = ...): ...
def get_context(self) -> SSL.Context: ...
Loading
Loading