|
3 | 3 | import hashlib
|
4 | 4 | import logging
|
5 | 5 | from typing import Any
|
6 |
| - |
| 6 | +import trio |
7 | 7 | import base58
|
8 | 8 | from cryptography import (
|
9 | 9 | x509,
|
|
16 | 16 | serialization,
|
17 | 17 | )
|
18 | 18 | from cryptography.hazmat.primitives.asymmetric import (
|
19 |
| - rsa, |
| 19 | + ec, |
20 | 20 | )
|
21 | 21 | from cryptography.hazmat.primitives.asymmetric.rsa import (
|
22 | 22 | RSAPrivateKey as CryptoRSAPrivateKey,
|
|
37 | 37 | ID,
|
38 | 38 | )
|
39 | 39 |
|
| 40 | +from ..constants import ( |
| 41 | + DEFAULT_CERTIFICATE_RENEWAL_THRESHOLD, |
| 42 | + DEFAULT_CERTIFICATE_LIFESPAN |
| 43 | +) |
40 | 44 | SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0"
|
41 | 45 | logger = logging.getLogger("libp2p.transport.webrtc.certificate")
|
42 | 46 |
|
43 |
| - |
| 47 | +# TODO: Once Datastore is implemented in python, add cert and priv_key storage |
| 48 | +# and management. |
44 | 49 | class WebRTCCertificate:
|
45 | 50 | """WebRTC certificate for connections"""
|
46 | 51 |
|
47 |
| - def __init__(self, cert: x509.Certificate, private_key: rsa.RSAPrivateKey) -> None: |
| 52 | + def __init__(self, cert: x509.Certificate, private_key: ec.EllipticCurvePrivateKey) -> None: |
48 | 53 | self.cert = cert
|
49 |
| - self.private_key = private_key |
| 54 | + self.private_key = private_key | None = None |
50 | 55 | self._fingerprint: str | None = None
|
51 | 56 | self._certhash: str | None = None
|
52 |
| - |
| 57 | + self.cancel_scope: trio.CancelScope = None |
53 | 58 | @classmethod
|
54 | 59 | def generate(cls) -> "WebRTCCertificate":
|
55 | 60 | """Generate a new self-signed certificate for WebRTC"""
|
56 |
| - # Generate private key |
57 |
| - private_key = rsa.generate_private_key( |
58 |
| - public_exponent=65537, |
59 |
| - key_size=2048, |
60 |
| - ) |
61 |
| - |
| 61 | + # Create instance first with None private key |
| 62 | + instance = cls.__new__(cls) |
| 63 | + instance._fingerprint = None |
| 64 | + instance._certhash = None |
| 65 | + |
| 66 | + # Generate private key using the instance method |
| 67 | + private_key = instance.loadOrCreatePrivateKey() |
| 68 | + |
62 | 69 | # Create certificate
|
63 |
| - common_name: Any = "libp2p-webrtc" |
64 |
| - subject = issuer = x509.Name( |
65 |
| - [ |
66 |
| - x509.NameAttribute(NameOID.COMMON_NAME, common_name), |
67 |
| - ] |
68 |
| - ) |
69 |
| - |
70 |
| - cert = ( |
71 |
| - x509.CertificateBuilder() |
72 |
| - .subject_name(subject) |
73 |
| - .issuer_name(issuer) |
74 |
| - .public_key(private_key.public_key()) |
75 |
| - .serial_number(x509.random_serial_number()) |
76 |
| - .not_valid_before(datetime.datetime.utcnow()) |
77 |
| - .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) |
78 |
| - .add_extension( |
79 |
| - x509.SubjectAlternativeName( |
80 |
| - [ |
81 |
| - x509.DNSName("localhost"), |
82 |
| - ] |
83 |
| - ), |
84 |
| - critical=False, |
85 |
| - ) |
86 |
| - .sign(private_key, hashes.SHA256()) |
87 |
| - ) |
88 |
| - |
89 |
| - return cls(cert, private_key) |
| 70 | + cert, pem = instance.loadOrCreateCertificate() |
| 71 | + |
| 72 | + # Set the certificate and private key on the instance |
| 73 | + instance.cert = cert |
| 74 | + instance.private_key = private_key |
| 75 | + |
| 76 | + return instance |
90 | 77 |
|
91 | 78 | @property
|
92 | 79 | def fingerprint(self) -> str:
|
@@ -208,7 +195,111 @@ def validate_pem_export(self) -> bool:
|
208 | 195 | raise ValueError("Invalid private key PEM footer")
|
209 | 196 |
|
210 | 197 | return True
|
| 198 | + |
| 199 | + def _getCertRenewalTime(self) -> int: |
| 200 | + # Calculate the renewal time in milliseconds until certificate expiry minus the renewal threshold. |
| 201 | + renew_at = self.cert.not_valid_after - datetime.timedelta(milliseconds=DEFAULT_CERTIFICATE_RENEWAL_THRESHOLD) |
| 202 | + now = datetime.datetime.now(datetime.timezone.utc) |
| 203 | + renewal_time_ms = int((renew_at - now).total_seconds() * 1000) |
| 204 | + return renewal_time_ms if renewal_time_ms > 0 else 100 |
| 205 | + |
| 206 | + |
| 207 | + def loadOrCreatePrivateKey(self, forceRenew = False) -> ec.EllipticCurvePrivateKey: |
| 208 | + """ |
| 209 | + Load the existing private key if available, or generate a new one. |
| 210 | +
|
| 211 | + Args: |
| 212 | + forceRenew (bool): If True, always generate a new private key even if one already exists. |
| 213 | + If False, return the existing private key if present. |
211 | 214 |
|
| 215 | + Returns: |
| 216 | + ec.EllipticCurvePrivateKey: The loaded or newly generated elliptic curve private key. |
| 217 | + """ |
| 218 | + # If private key is already present and not enforced to create new |
| 219 | + if self.private_key != None and not forceRenew: |
| 220 | + return self.private_key |
| 221 | + |
| 222 | + # Create a new private key |
| 223 | + self.private_key = ec.generate_private_key(ec.SECP256R1()) |
| 224 | + return self.private_key |
| 225 | + |
| 226 | + def loadOrCreateCertificate( |
| 227 | + self, |
| 228 | + private_key: ec.EllipticCurvePrivateKey | None, |
| 229 | + forceRenew: bool = False |
| 230 | + ) -> tuple[x509.Certificate, str, str]: |
| 231 | + """ |
| 232 | + Generate or load a self-signed WebRTC certificate for libp2p direct connections. |
| 233 | +
|
| 234 | + If a valid certificate already exists and is not expired, and the public key matches, |
| 235 | + it will be reused unless forceRenew is True. Otherwise, a new certificate is generated. |
| 236 | +
|
| 237 | + Args: |
| 238 | + private_key (ec.EllipticCurvePrivateKey | None): The private key to use for signing the certificate. |
| 239 | + If None, uses self.private_key. |
| 240 | + forceRenew (bool): If True, always generate a new certificate even if the current one is valid. |
| 241 | +
|
| 242 | + Returns: |
| 243 | + tuple[x509.Certificate, str, str]: The certificate object, its PEM-encoded string, and the base64url-encoded SHA-256 hash of the certificate. |
| 244 | +
|
| 245 | + Raises: |
| 246 | + Exception: If no private key is available to issue a certificate. |
| 247 | + """ |
| 248 | + if private_key is None: |
| 249 | + if self.private_key is None: |
| 250 | + raise Exception("Can't issue certificate without private key") |
| 251 | + private_key = self.private_key |
| 252 | + |
| 253 | + if self.cert is not None and not forceRenew: |
| 254 | + # Check if certificate has to be renewed |
| 255 | + renewal_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(milliseconds=DEFAULT_CERTIFICATE_RENEWAL_THRESHOLD) |
| 256 | + isExpired = renewal_time >= self.cert.not_valid_after |
| 257 | + if not isExpired: |
| 258 | + # Check if the certificate's public key matches with provided key pair |
| 259 | + if self.cert.public_key().public_numbers() == private_key.public_key().public_numbers(): |
| 260 | + cert_pem, _ = self.to_pem() |
| 261 | + cert_hash = self.certhash() |
| 262 | + return (self.cert, cert_pem, cert_hash) |
| 263 | + |
| 264 | + common_name: str = "libp2p-webrtc" |
| 265 | + subject = issuer = x509.Name( |
| 266 | + [ |
| 267 | + x509.NameAttribute(NameOID.COMMON_NAME, common_name), |
| 268 | + ] |
| 269 | + ) |
| 270 | + |
| 271 | + cert = ( |
| 272 | + x509.CertificateBuilder() |
| 273 | + .subject_name(subject) |
| 274 | + .issuer_name(issuer) |
| 275 | + .public_key(private_key.public_key()) |
| 276 | + .serial_number(x509.random_serial_number()) |
| 277 | + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) |
| 278 | + .not_valid_after( |
| 279 | + datetime.datetime.now(datetime.timezone.utc) + |
| 280 | + datetime.timedelta(milliseconds=DEFAULT_CERTIFICATE_LIFESPAN) |
| 281 | + ) |
| 282 | + .add_extension( |
| 283 | + x509.SubjectAlternativeName( |
| 284 | + [ |
| 285 | + x509.DNSName("localhost"), |
| 286 | + ] |
| 287 | + ), |
| 288 | + critical=False, |
| 289 | + ) |
| 290 | + .sign(private_key, hashes.SHA256()) |
| 291 | + ) |
| 292 | + self.cert = cert |
| 293 | + pem = cert.public_bytes(Encoding.PEM).decode('utf-8') |
| 294 | + cert_pem, _ = self.to_pem() |
| 295 | + cert_hash = self.certhash() |
| 296 | + return (cert, cert_pem, cert_hash) |
| 297 | + |
| 298 | + async def renewal_loop(self): |
| 299 | + while True: |
| 300 | + await trio.sleep(self._getCertRenewalTime) |
| 301 | + logger.Debug("Renewing TLS certificate") |
| 302 | + await self.loadOrCreateCertificate(self.private_key, True) |
212 | 303 |
|
213 | 304 | def create_webrtc_multiaddr(
|
214 | 305 | ip: str, peer_id: ID, certhash: str, direct: bool = False
|
|
0 commit comments