diff --git a/libp2p/security/exceptions.py b/libp2p/security/exceptions.py index bff09d933..d874ce745 100644 --- a/libp2p/security/exceptions.py +++ b/libp2p/security/exceptions.py @@ -5,3 +5,7 @@ class HandshakeFailure(BaseLibp2pError): pass + + +class SecurityError(BaseLibp2pError): + pass diff --git a/libp2p/security/tls/__init__.py b/libp2p/security/tls/__init__.py new file mode 100644 index 000000000..dd8e3fb90 --- /dev/null +++ b/libp2p/security/tls/__init__.py @@ -0,0 +1,36 @@ +""" +TLS security transport for libp2p. + +This module provides a comprehensive TLS transport implementation +that follows the Go libp2p TLS specification. +""" + +from libp2p.security.tls.transport import ( + TLSTransport, + IdentityConfig, + create_tls_transport, + PROTOCOL_ID, +) +from libp2p.security.tls.io import TLSReadWriter +from libp2p.security.tls.certificate import ( + generate_certificate, + create_cert_template, + verify_certificate_chain, + pub_key_from_cert_chain, + SignedKey, + ALPN_PROTOCOL +) + +__all__ = [ + "TLSTransport", + "IdentityConfig", + "TLSReadWriter", + "create_tls_transport", + "generate_certificate", + "create_cert_template", + "verify_certificate_chain", + "pub_key_from_cert_chain", + "SignedKey", + "PROTOCOL_ID", + "ALPN_PROTOCOL" +] diff --git a/libp2p/security/tls/certificate.py b/libp2p/security/tls/certificate.py new file mode 100644 index 000000000..7921d2f8d --- /dev/null +++ b/libp2p/security/tls/certificate.py @@ -0,0 +1,346 @@ +""" +TLS certificate utilities for libp2p. + +This module provides certificate generation and verification functions +that embed libp2p peer identity information in X.509 extensions. +""" + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +import os +from typing import Any + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed448, ed25519, rsa +from cryptography.x509.oid import NameOID, ObjectIdentifier + +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.crypto.serialization import deserialize_public_key + +# ALPN protocol for libp2p TLS +ALPN_PROTOCOL = "libp2p" + +# Custom OID for libp2p peer identity extension (same as Rust implementation) +LIBP2P_EXTENSION_OID = ObjectIdentifier("1.3.6.1.4.1.53594.1.1") + +# Prefix used when signing the TLS certificate public key with the libp2p host key +# to bind the X.509 certificate to the libp2p identity. +LIBP2P_CERT_PREFIX: bytes = b"libp2p-tls-handshake:" + + +@dataclass +class SignedKey: + """Represents a signed public key embedded in certificate extension.""" + + public_key_bytes: bytes + signature: bytes + + +def encode_signed_key(public_key_bytes: bytes, signature: bytes) -> bytes: + """ + ASN.1-encode the SignedKey structure for inclusion in the libp2p X.509 extension. + + Args: + public_key_bytes: libp2p protobuf-encoded public key bytes + signature: signature over prefix+certificate public key + + Returns: + DER-encoded bytes representing the SignedKey sequence + + """ + + # DER encoding helpers + def _encode_len(n: int) -> bytes: + if n < 0x80: + return bytes([n]) + length_bytes = n.to_bytes((n.bit_length() + 7) // 8, byteorder="big") + return bytes([0x80 | len(length_bytes)]) + length_bytes + + def _encode_octet_string(data: bytes) -> bytes: + return bytes([0x04]) + _encode_len(len(data)) + data + + def _encode_sequence(content: bytes) -> bytes: + return bytes([0x30]) + _encode_len(len(content)) + content + + content = _encode_octet_string(public_key_bytes) + _encode_octet_string(signature) + return _encode_sequence(content) + + +def decode_signed_key(der_bytes: bytes) -> SignedKey: + """ + Parse DER-encoded SignedKey from the libp2p X.509 extension value. + + Args: + der_bytes: DER bytes for SignedKey + + Returns: + Parsed SignedKey instance + + """ + + # Minimal DER parser for: SEQUENCE { OCTET STRING, OCTET STRING } + def _expect_byte(data: bytes, idx: int, b: int) -> int: + if idx >= len(data) or data[idx] != b: + raise ValueError("Invalid DER: unexpected tag") + return idx + 1 + + def _read_len(data: bytes, idx: int) -> tuple[int, int]: + if idx >= len(data): + raise ValueError("Invalid DER: truncated length") + first = data[idx] + idx += 1 + if first < 0x80: + return first, idx + num = first & 0x7F + if idx + num > len(data): + raise ValueError("Invalid DER: truncated long length") + val = int.from_bytes(data[idx : idx + num], "big") + return val, idx + num + + i = 0 + i = _expect_byte(der_bytes, i, 0x30) # SEQUENCE + seq_len, i = _read_len(der_bytes, i) + end_seq = i + seq_len + + i = _expect_byte(der_bytes, i, 0x04) # OCTET STRING + pk_len, i = _read_len(der_bytes, i) + if i + pk_len > len(der_bytes): + raise ValueError("Invalid DER: truncated public key") + pk_bytes = der_bytes[i : i + pk_len] + i += pk_len + + i = _expect_byte(der_bytes, i, 0x04) # OCTET STRING + sig_len, i = _read_len(der_bytes, i) + if i + sig_len > len(der_bytes): + raise ValueError("Invalid DER: truncated signature") + sig_bytes = der_bytes[i : i + sig_len] + i += sig_len + + if i != end_seq: + raise ValueError("Invalid DER: extra data in SignedKey") + + return SignedKey(public_key_bytes=pk_bytes, signature=sig_bytes) + + +def create_cert_template() -> x509.CertificateBuilder: + """ + Create a certificate template for libp2p TLS certificates. + + Returns: + Certificate builder template + + """ + # Serial: random 64-bit value + serial = int.from_bytes(os.urandom(8), "big") + not_before = datetime.now(timezone.utc) - timedelta(hours=1) + # ~100 years + not_after = not_before + timedelta(days=365 * 100) + + # Create name attributes with explicit typing to satisfy strict type checker + common_name_value: Any = "libp2p" + subject_name = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, common_name_value)] + ) + issuer_name = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, common_name_value)] + ) + + builder = ( + x509.CertificateBuilder() + .serial_number(serial) + .not_valid_before(not_before) + .not_valid_after(not_after) + .subject_name(subject_name) + .issuer_name(issuer_name) + ) + return builder + + +def add_libp2p_extension( + cert_builder: x509.CertificateBuilder, peer_public_key: PublicKey, signature: bytes +) -> x509.CertificateBuilder: + """ + Add libp2p peer identity extension to certificate. + + Args: + cert_builder: Certificate builder to modify + peer_public_key: Peer's public key to embed + signature: Signature over the certificate's public key + + Returns: + Certificate builder with libp2p extension + + """ + sk_der = encode_signed_key(peer_public_key.serialize(), signature) + ext = x509.UnrecognizedExtension(LIBP2P_EXTENSION_OID, sk_der) + return cert_builder.add_extension(ext, critical=False) + + +def generate_certificate( + private_key: PrivateKey, cert_template: x509.CertificateBuilder +) -> tuple[str, str]: + """ + Generate a self-signed certificate with libp2p extensions. + + Args: + private_key: Private key for signing + cert_template: Certificate template + + Returns: + Tuple of (certificate PEM, private key PEM) + + """ + # Generate an ephemeral TLS key (ECDSA P-256) + tls_private_key = ec.generate_private_key(ec.SECP256R1()) + + # Build SignedKey over the certificate's SubjectPublicKeyInfo + spki_der = tls_private_key.public_key().public_bytes( + serialization.Encoding.DER, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + signature = private_key.sign(LIBP2P_CERT_PREFIX + spki_der) + + builder = cert_template + builder = builder.public_key(tls_private_key.public_key()) + # Self-signed + builder = add_libp2p_extension(builder, private_key.get_public_key(), signature) + certificate = builder.sign( + private_key=tls_private_key, + algorithm=hashes.SHA256(), + ) + + cert_pem = certificate.public_bytes(serialization.Encoding.PEM).decode() + key_pem = tls_private_key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ).decode() + return cert_pem, key_pem + + +def verify_certificate_chain(cert_chain: list[x509.Certificate]) -> PublicKey: + """ + Verify certificate chain and extract peer public key from libp2p extension. + + Args: + cert_chain: List of certificates in the chain + + Returns: + Public key from libp2p extension + + Raises: + SecurityError: If verification fails + + """ + if len(cert_chain) != 1: + raise ValueError("expected one certificates in the chain") + + [cert] = cert_chain + + # 1) Validity window + now = datetime.now(timezone.utc) + not_before = getattr(cert, "not_valid_before_utc", None) + not_after = getattr(cert, "not_valid_after_utc", None) + if not_before is None: + not_before = cert.not_valid_before.replace(tzinfo=timezone.utc) + if not_after is None: + not_after = cert.not_valid_after.replace(tzinfo=timezone.utc) + if not_before > now or not_after < now: + raise ValueError("certificate has expired or is not yet valid") + + # 2) Find libp2p extension + ext_value: bytes | None = None + for idx, ext in enumerate(cert.extensions): + if ext.oid == LIBP2P_EXTENSION_OID: + # Remove from unhandled critical list if necessary by re-creating cert + # object is non-trivial here; we'll just parse value + ext_value = ( + ext.value.value + if isinstance(ext.value, x509.UnrecognizedExtension) + else None + ) + break + if ext_value is None: + raise ValueError("expected certificate to contain the key extension") + + # 3) Verify self-signature of the certificate + pub = cert.public_key() + # Verify self-signature with correct algorithm based on key type + try: + hash_alg = cert.signature_hash_algorithm + if hash_alg is None: + raise ValueError("Certificate signature hash algorithm is None") + + if isinstance(pub, ec.EllipticCurvePublicKey): + pub.verify(cert.signature, cert.tbs_certificate_bytes, ec.ECDSA(hash_alg)) + elif isinstance(pub, rsa.RSAPublicKey): + from cryptography.hazmat.primitives.asymmetric import padding + + pub.verify( + cert.signature, cert.tbs_certificate_bytes, padding.PKCS1v15(), hash_alg + ) + elif isinstance(pub, (ed25519.Ed25519PublicKey, ed448.Ed448PublicKey)): + pub.verify(cert.signature, cert.tbs_certificate_bytes) + elif isinstance(pub, dsa.DSAPublicKey): + pub.verify(cert.signature, cert.tbs_certificate_bytes, hash_alg) + else: + raise ValueError(f"Unsupported key type for verification: {type(pub)}") + except Exception as e: + raise ValueError(f"certificate verification failed: {e}") + + # 4) Verify extension signature + signed = decode_signed_key(ext_value) + host_pub = deserialize_public_key(signed.public_key_bytes) + + spki_der = cert.public_key().public_bytes( + serialization.Encoding.DER, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + message = LIBP2P_CERT_PREFIX + spki_der + if not host_pub.verify(message, signed.signature): + raise ValueError("signature invalid") + + return host_pub + + +def pub_key_from_cert_chain(cert_chain: list[x509.Certificate]) -> PublicKey: + """ + Extract public key from certificate chain. + + This is an alias for verify_certificate_chain for compatibility. + + Args: + cert_chain: Certificate chain + + Returns: + Public key + + """ + return verify_certificate_chain(cert_chain) + + +def generate_self_signed_cert() -> tuple[ec.EllipticCurvePrivateKey, x509.Certificate]: + """ + Generate a self-signed certificate for testing purposes. + + This is a utility function based on the guide examples. + + Returns: + Tuple of (private key, certificate) + + """ + key = ec.generate_private_key(ec.SECP256R1()) + common_name_value: Any = "py-libp2p" + name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name_value)]) + cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name) + .public_key(key.public_key()) + .serial_number(int.from_bytes(os.urandom(8), "big")) + .not_valid_before(datetime.now(timezone.utc) - timedelta(days=1)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=30)) + .sign(key, hashes.SHA256()) + ) + return key, cert diff --git a/libp2p/security/tls/io.py b/libp2p/security/tls/io.py new file mode 100644 index 000000000..30d13bc36 --- /dev/null +++ b/libp2p/security/tls/io.py @@ -0,0 +1,258 @@ +""" +TLS I/O utilities for libp2p. + +This module provides TLS-specific message reading and writing functionality, +similar to how noise handles encrypted communication. +""" + +import ssl + +from cryptography import x509 + +from libp2p.abc import IRawConnection +from libp2p.io.abc import EncryptedMsgReadWriter + + +class TLSReadWriter(EncryptedMsgReadWriter): + """ + TLS encrypted message reader/writer. + + This class handles TLS encryption/decryption over a raw connection, + similar to NoiseTransportReadWriter in the noise implementation. + """ + + def __init__( + self, + conn: IRawConnection, + ssl_context: ssl.SSLContext, + server_side: bool = False, + server_hostname: str | None = None, + ): + """ + Initialize TLS reader/writer. + + Args: + conn: Raw connection to wrap + ssl_context: SSL context for TLS operations + server_side: Whether to act as TLS server + server_hostname: Server hostname for client connections + + """ + self.raw_connection = conn + self.ssl_context = ssl_context + self.server_side = server_side + self.server_hostname = server_hostname + # These will be initialized in handshake() and required for operation + self._ssl_socket: ssl.SSLObject + self._in_bio: ssl.MemoryBIO + self._out_bio: ssl.MemoryBIO + self._peer_certificate: x509.Certificate | None = None + self._handshake_complete = False + self._negotiated_protocol: str | None = None + + async def handshake(self) -> None: + """ + Perform TLS handshake. + + Raises: + HandshakeFailure: If handshake fails + + """ + # Perform a blocking-style TLS handshake using memory BIOs bridged to + # Trio stream + in_bio = ssl.MemoryBIO() + out_bio = ssl.MemoryBIO() + ssl_obj = self.ssl_context.wrap_bio( + in_bio, + out_bio, + server_side=self.server_side, + server_hostname=self.server_hostname, + ) + self._ssl_socket = ssl_obj + self._in_bio = in_bio + self._out_bio = out_bio + + # Drive the handshake + while True: + try: + ssl_obj.do_handshake() + break + except ssl.SSLWantReadError: + # flush data to wire + data = out_bio.read() + if data: + await self.raw_connection.write(data) + # read more from wire + incoming = await self.raw_connection.read(4096) + if incoming: + in_bio.write(incoming) + except ssl.SSLWantWriteError: + data = out_bio.read() + if data: + await self.raw_connection.write(data) + except ssl.SSLCertVerificationError: + # Ignore built-in verification errors; we verify manually afterwards. + break + + # Flush any remaining handshake data + data = out_bio.read() + if data: + await self.raw_connection.write(data) + + # Populate cert and ALPN + # For our usage we skip builtin verification, so peer cert may be self-signed. + # Use binary form if available; otherwise use text form unsupported. + try: + cert_bin = ssl_obj.getpeercert(binary_form=True) + except TypeError: + cert_bin = None + if cert_bin: + self._peer_certificate = x509.load_der_x509_certificate(cert_bin) + self._negotiated_protocol = ssl_obj.selected_alpn_protocol() + self._handshake_complete = True + + def get_peer_certificate(self) -> x509.Certificate | None: + """ + Get the peer's certificate after handshake. + + Returns: + Peer certificate or None if not available + + """ + return self._peer_certificate + + async def write_msg(self, msg: bytes) -> None: + """ + Write an encrypted message. + + Args: + msg: Message to encrypt and send + + """ + # Ensure handshake was called + if not self._handshake_complete: + raise RuntimeError("Call handshake() first") + # write plaintext into SSL object and flush ciphertext to transport + remaining = msg + while remaining: + try: + n = self._ssl_socket.write(remaining) + remaining = remaining[n:] + except ssl.SSLWantWriteError: + pass + # flush any TLS records produced + while True: + data = self._out_bio.read() + if not data: + break + await self.raw_connection.write(data) + + async def read_msg(self) -> bytes: + """ + Read and decrypt a message. + + Returns: + Decrypted message bytes + + """ + # Ensure handshake was called + if not self._handshake_complete: + raise RuntimeError("Call handshake() first") + + # Try to read decrypted application data; if need more TLS bytes, + # fetch from network + max_attempts = 100 # Prevent infinite loops + attempt = 0 + + while attempt < max_attempts: + attempt += 1 + try: + data = self._ssl_socket.read(65536) + if data: + return data + # If we get here, ssl_socket.read() returned empty data + # Check if connection is closed by trying to read from raw connection + try: + incoming = await self.raw_connection.read(4096) + if not incoming: + return b"" # Connection closed + self._in_bio.write(incoming) + continue # Try reading again with new data + except Exception: + return b"" # Connection error + except ssl.SSLWantReadError: + # flush any pending TLS data + pending = self._out_bio.read() + if pending: + await self.raw_connection.write(pending) + # get more ciphertext + incoming = await self.raw_connection.read(4096) + if not incoming: + return b"" + self._in_bio.write(incoming) + continue + except Exception: + # Any other SSL error - connection is likely broken + return b"" + + # If we've exhausted attempts, return empty + return b"" + + def encrypt(self, data: bytes) -> bytes: + """ + Encrypt data for transmission. + + Args: + data: Data to encrypt + + Returns: + Encrypted data + + """ + # In TLS, encryption is handled at the SSL layer during write_msg + # This method exists for interface compatibility + return data + + def decrypt(self, data: bytes) -> bytes: + """ + Decrypt received data. + + Args: + data: Encrypted data to decrypt + + Returns: + Decrypted data + + """ + # In TLS, decryption is handled at the SSL layer during read_msg + # This method exists for interface compatibility + return data + + async def close(self) -> None: + """Close the TLS connection.""" + try: + if self._ssl_socket is not None: + try: + self._ssl_socket.unwrap() + except Exception: + pass + finally: + await self.raw_connection.close() + + def get_negotiated_protocol(self) -> str | None: + """ + Return the ALPN-negotiated protocol (e.g., selected muxer) if any. + """ + return self._negotiated_protocol + + def get_remote_address(self) -> tuple[str, int] | None: + """ + Get remote address from underlying connection. + + Returns: + Remote address tuple or None + + """ + if hasattr(self.raw_connection, "get_remote_address"): + return self.raw_connection.get_remote_address() + return None diff --git a/libp2p/security/tls/transport.py b/libp2p/security/tls/transport.py new file mode 100644 index 000000000..8cccc96f7 --- /dev/null +++ b/libp2p/security/tls/transport.py @@ -0,0 +1,287 @@ +from dataclasses import dataclass +import ssl +from typing import Any + +from cryptography import x509 + +from libp2p.abc import IRawConnection, ISecureConn, ISecureTransport +from libp2p.crypto.keys import KeyPair, PrivateKey +from libp2p.custom_types import TProtocol +from libp2p.peer.id import ID +from libp2p.security.secure_session import SecureSession +from libp2p.security.tls.certificate import ( + ALPN_PROTOCOL, + create_cert_template, + generate_certificate, + verify_certificate_chain, +) +from libp2p.security.tls.io import TLSReadWriter + +# Protocol ID for TLS transport +PROTOCOL_ID = TProtocol("/tls/1.0.0") + + +@dataclass +class IdentityConfig: + """Configuration for TLS identity.""" + + cert_template: x509.CertificateBuilder | None = None + key_log_writer: Any | None = None + + +class TLSTransport(ISecureTransport): + """ + TLS transport implementation following the noise pattern. + + Features: + - TLS 1.3 support + - Custom certificate generation with libp2p extensions + - Peer ID verification + - ALPN protocol negotiation + """ + + libp2p_privkey: PrivateKey + local_peer: ID + early_data: bytes | None + + def __init__( + self, + libp2p_keypair: KeyPair, + early_data: bytes | None = None, + muxers: list[str] | None = None, + identity_config: IdentityConfig | None = None, + ): + """Initialize TLS transport.""" + self.libp2p_privkey = libp2p_keypair.private_key + self.local_peer = ID.from_pubkey(libp2p_keypair.public_key) + self.early_data = early_data + # Optional list of preferred stream muxers for ALPN negotiation. + self._preferred_muxers = muxers or [] + # Optional identity config for certificate template and key log writer. + self._identity_config = identity_config + # Generate and cache a stable identity certificate for this transport + template = ( + self._identity_config.cert_template + if self._identity_config and self._identity_config.cert_template + else create_cert_template() + ) + self._cert_pem, self._key_pem = generate_certificate( + self.libp2p_privkey, template + ) + # Trusted peer certs (PEM) for accepting self-signed peers during tests + self._trusted_peer_certs_pem: list[str] = [] + + def create_ssl_context(self, server_side: bool = False) -> ssl.SSLContext: + """ + Create SSL context for TLS connections. + + Args: + server_side: Whether this is for server-side connections + + Returns: + Configured SSL context + + """ + # Placeholder for SSL context creation following libp2p TLS 1.3 profile. + # Expected responsibilities: + # - TLS 1.3 only + # - Insecure cert verification here, custom verification post-handshake + # - Set ALPN protocols: preferred muxers + "libp2p" + # - Apply key log writer if provided in identity_config + # - Disable SNI for client-side connections + ctx = ssl.SSLContext( + ssl.PROTOCOL_TLS_SERVER if server_side else ssl.PROTOCOL_TLS_CLIENT + ) + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + # We do our own verification of the peer certificate + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_OPTIONAL if server_side else ssl.CERT_NONE + + # Load our cached self-signed certificate bound to libp2p identity + import tempfile + + with tempfile.NamedTemporaryFile("w", delete=False) as cert_file: + cert_file.write(self._cert_pem) + cert_path = cert_file.name + with tempfile.NamedTemporaryFile("w", delete=False) as key_file: + key_file.write(self._key_pem) + key_path = key_file.name + ctx.load_cert_chain(certfile=cert_path, keyfile=key_path) + + # If we have trusted peer certs, configure verification to accept those + if server_side and self._trusted_peer_certs_pem: + with tempfile.NamedTemporaryFile("w", delete=False) as cafile: + cafile.write("".join(self._trusted_peer_certs_pem)) + ca_path = cafile.name + try: + ctx.load_verify_locations(cafile=ca_path) + ctx.verify_mode = ssl.CERT_OPTIONAL + except Exception: + pass + + # ALPN: provide list; without a select-callback we accept server preference. + alpn_list = list(self._preferred_muxers) + [ALPN_PROTOCOL] + try: + ctx.set_alpn_protocols(alpn_list) + except Exception: + # ALPN may be unavailable; proceed without early muxer negotiation + pass + + # key log file support if provided as path-like + if self._identity_config and self._identity_config.key_log_writer: + # Accept a file path or a file-like with name + keylog_path = None + writer = self._identity_config.key_log_writer + if isinstance(writer, str): + keylog_path = writer + elif hasattr(writer, "name"): + keylog_path = getattr(writer, "name") + if keylog_path: + try: + ctx.keylog_filename = keylog_path + except Exception: + pass + + return ctx + + async def secure_inbound(self, conn: IRawConnection) -> ISecureConn: + """ + Secure an inbound connection as server. + + Args: + conn: Raw connection to secure + + Returns: + Secured connection (SecureSession) + + """ + # Create SSL context for server + ssl_context = self.create_ssl_context(server_side=True) + + # Create TLS reader/writer + tls_reader_writer = TLSReadWriter( + conn=conn, ssl_context=ssl_context, server_side=True + ) + + # Perform handshake + await tls_reader_writer.handshake() + + # Extract peer information + peer_cert = tls_reader_writer.get_peer_certificate() + if not peer_cert: + raise ValueError("missing peer certificate") + + # Extract remote public key from certificate + remote_public_key = self._extract_public_key_from_cert(peer_cert) + remote_peer_id = ID.from_pubkey(remote_public_key) + + # Return SecureSession like noise does + return SecureSession( + local_peer=self.local_peer, + local_private_key=self.libp2p_privkey, + remote_peer=remote_peer_id, + remote_permanent_pubkey=remote_public_key, + is_initiator=False, + conn=tls_reader_writer, + ) + + async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: + """ + Secure an outbound connection as client. + + Args: + conn: Raw connection to secure + peer_id: Expected peer ID + + Returns: + Secured connection (SecureSession) + + """ + # Create SSL context for client + ssl_context = self.create_ssl_context(server_side=False) + + # Create TLS reader/writer + tls_reader_writer = TLSReadWriter( + conn=conn, ssl_context=ssl_context, server_side=False + ) + + # Perform handshake + await tls_reader_writer.handshake() + + # Extract peer information + peer_cert = tls_reader_writer.get_peer_certificate() + if not peer_cert: + raise ValueError("missing peer certificate") + + # Extract and verify remote public key + remote_public_key = self._extract_public_key_from_cert(peer_cert) + remote_peer_id = ID.from_pubkey(remote_public_key) + + if remote_peer_id != peer_id: + raise ValueError( + f"Peer ID mismatch: expected {peer_id} got {remote_peer_id}" + ) + + # Return SecureSession like noise does + return SecureSession( + local_peer=self.local_peer, + local_private_key=self.libp2p_privkey, + remote_peer=peer_id, + remote_permanent_pubkey=remote_public_key, + is_initiator=True, + conn=tls_reader_writer, + ) + + def _extract_public_key_from_cert(self, cert: x509.Certificate) -> Any: + """Extract public key from certificate.""" + # Use our chain verifier to extract the host public key + return verify_certificate_chain([cert]) + + def get_protocol_id(self) -> TProtocol: + """Get the protocol ID for this transport.""" + return PROTOCOL_ID + + def get_preferred_muxers(self) -> list[str]: + """ + Return the list of preferred stream muxers for ALPN negotiation. + """ + return list(self._preferred_muxers) + + def get_negotiated_muxer(self) -> str | None: + """ + Placeholder: return the muxer negotiated via ALPN, if any. + """ + # Negotiated muxer is available from the TLSReadWriter after handshake. + # It's surfaced on the SecureSession via connection state in other impls. + # For now, not exposed at this layer. + return None + + # Expose local certificate for tests + def get_certificate_pem(self) -> str: + return self._cert_pem + + def trust_peer_cert_pem(self, pem: str) -> None: + self._trusted_peer_certs_pem.append(pem) + + +# Factory function for creating TLS transport +def create_tls_transport( + libp2p_keypair: KeyPair, + early_data: bytes | None = None, + muxers: list[str] | None = None, + identity_config: IdentityConfig | None = None, +) -> TLSTransport: + """ + Create a new TLS transport. + + Args: + libp2p_keypair: Key pair for the local peer + early_data: Optional early data for TLS handshake + muxers: Optional list of preferred stream muxer protocol IDs for ALPN + identity_config: Optional TLS identity config (cert template, key log writer) + + Returns: + TLS transport instance + + """ + return TLSTransport(libp2p_keypair, early_data, muxers, identity_config) diff --git a/tests/core/security/tls/test_certificate.py b/tests/core/security/tls/test_certificate.py new file mode 100644 index 000000000..61f34ba2e --- /dev/null +++ b/tests/core/security/tls/test_certificate.py @@ -0,0 +1,44 @@ +from cryptography import x509 +from cryptography.x509.oid import ObjectIdentifier + +from libp2p.crypto.secp256k1 import create_new_key_pair as create_secp256k1 +from libp2p.security.tls import certificate as certmod + + +def test_signedkey_asn1_roundtrip(): + pub = b"pub-bytes-example" + sig = b"sig-bytes-example" + der = certmod.encode_signed_key(pub, sig) + sk = certmod.decode_signed_key(der) + assert sk.public_key_bytes == b"pub-bytes-example" + assert sk.signature == sig + + +def test_generate_certificate_has_libp2p_extension_noncritical(): + keypair = create_secp256k1() + tmpl = certmod.create_cert_template() + cert_pem, _ = certmod.generate_certificate(keypair.private_key, tmpl) + cert = x509.load_pem_x509_certificate(cert_pem.encode()) + + # Find extension + found = False + for ext in cert.extensions: + if ext.oid == ObjectIdentifier("1.3.6.1.4.1.53594.1.1"): + found = True + assert ext.critical is False + sk = certmod.decode_signed_key(ext.value.value) + assert isinstance(sk.public_key_bytes, (bytes, bytearray)) + assert isinstance(sk.signature, (bytes, bytearray)) + break + assert found, "libp2p extension missing" + + +def test_verify_certificate_chain_extracts_host_public_key(): + # Generate cert for a new host key + keypair = create_secp256k1() + tmpl = certmod.create_cert_template() + cert_pem, _ = certmod.generate_certificate(keypair.private_key, tmpl) + cert = x509.load_pem_x509_certificate(cert_pem.encode()) + + pub = certmod.verify_certificate_chain([cert]) + assert pub.to_bytes() == keypair.public_key.to_bytes() diff --git a/tests/core/security/tls/test_tls.py b/tests/core/security/tls/test_tls.py new file mode 100644 index 000000000..dcefbce2b --- /dev/null +++ b/tests/core/security/tls/test_tls.py @@ -0,0 +1,62 @@ +import pytest + +from libp2p import generate_new_rsa_identity +from libp2p.security.tls.transport import TLSTransport +from tests.utils.factories import tls_conn_factory + + +@pytest.mark.trio +async def test_tls_basic_handshake(nursery): + keypair_a = generate_new_rsa_identity() + keypair_b = generate_new_rsa_identity() + + t_a = TLSTransport(keypair_a) + t_b = TLSTransport(keypair_b) + # Trust each other's certs during tests to avoid system verify failure + t_a.trust_peer_cert_pem(t_b.get_certificate_pem()) + t_b.trust_peer_cert_pem(t_a.get_certificate_pem()) + + async with tls_conn_factory( + nursery, client_transport=t_a, server_transport=t_b + ) as ( + client_conn, + server_conn, + ): + assert client_conn.get_local_peer() == t_a.local_peer + assert server_conn.get_local_peer() == t_b.local_peer + assert client_conn.get_remote_peer() == t_b.local_peer + assert server_conn.get_remote_peer() == t_a.local_peer + + await server_conn.write(b"hello") + assert await client_conn.read(5) == b"hello" + + await client_conn.write(b"world") + assert await server_conn.read(5) == b"world" + + await client_conn.close() + await server_conn.close() + + +DATA_0 = b"hello" +DATA_1 = b"x" * 1500 +DATA_2 = b"bye!" + + +@pytest.mark.trio +async def test_tls_transport(nursery): + async with tls_conn_factory(nursery): + # handshake succeeds if factory returns + pass + + +@pytest.mark.trio +async def test_tls_connection(nursery): + async with tls_conn_factory(nursery) as (local, remote): + await local.write(DATA_0) + await local.write(DATA_1) + + assert DATA_0 == await remote.read(len(DATA_0)) + assert DATA_1 == await remote.read(len(DATA_1)) + + await local.write(DATA_2) + assert DATA_2 == await remote.read(len(DATA_2)) diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 75639e369..2531cc0b1 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -93,6 +93,10 @@ Transport as NoiseTransport, ) import libp2p.security.secio.transport as secio +from libp2p.security.tls.transport import ( + PROTOCOL_ID as TLS_PROTOCOL_ID, + TLSTransport, +) from libp2p.stream_muxer.mplex.mplex import ( MPLEX_PROTOCOL_ID, Mplex, @@ -192,6 +196,12 @@ def security_options_factory(key_pair: KeyPair) -> TSecurityOptions: transport_factory = secio_transport_factory elif protocol_id == NOISE_PROTOCOL_ID: transport_factory = noise_transport_factory + elif protocol_id == TLS_PROTOCOL_ID: + + def tls_transport_factory(key_pair): + return TLSTransport(key_pair) + + transport_factory = tls_transport_factory else: raise Exception(f"security transport {protocol_id} is not supported") return {protocol_id: transport_factory(key_pair)} @@ -277,6 +287,42 @@ async def upgrade_remote_conn() -> None: yield local_secure_conn, remote_secure_conn +@asynccontextmanager +async def tls_conn_factory( + nursery: trio.Nursery, + client_transport: TLSTransport | None = None, + server_transport: TLSTransport | None = None, +) -> AsyncIterator[tuple[ISecureConn, ISecureConn]]: + local_transport = client_transport or TLSTransport(create_secp256k1_key_pair()) + remote_transport = server_transport or TLSTransport(create_secp256k1_key_pair()) + # Trust each other's certs for test handshake + local_transport.trust_peer_cert_pem(remote_transport.get_certificate_pem()) + remote_transport.trust_peer_cert_pem(local_transport.get_certificate_pem()) + + local_secure_conn: ISecureConn | None = None + remote_secure_conn: ISecureConn | None = None + + async def upgrade_local_conn(local_conn: IRawConnection) -> None: + nonlocal local_secure_conn + local_secure_conn = await local_transport.secure_outbound( + local_conn, remote_transport.local_peer + ) + + async def upgrade_remote_conn(remote_conn: IRawConnection) -> None: + nonlocal remote_secure_conn + remote_secure_conn = await remote_transport.secure_inbound(remote_conn) + + async with raw_conn_factory(nursery) as (local_conn, remote_conn): + async with trio.open_nursery() as n: + n.start_soon(upgrade_local_conn, local_conn) + n.start_soon(upgrade_remote_conn, remote_conn) + if local_secure_conn is None or remote_secure_conn is None: + raise Exception( + "local or remote secure conn has not been successfully upgraded" + ) + yield local_secure_conn, remote_secure_conn + + class SwarmFactory(factory.Factory): class Meta: model = Swarm