diff --git a/lib/charms/mongodb/v0/mongodb_tls.py b/lib/charms/mongodb/v0/mongodb_tls.py index b128bdbef..b9ead22cc 100644 --- a/lib/charms/mongodb/v0/mongodb_tls.py +++ b/lib/charms/mongodb/v0/mongodb_tls.py @@ -13,10 +13,10 @@ import socket from typing import List, Optional, Tuple -from charms.tls_certificates_interface.v3.tls_certificates import ( +from charms.tls_certificates_interface.v1.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, - TLSCertificatesRequiresV3, + TLSCertificatesRequiresV1, generate_csr, generate_private_key, ) @@ -39,7 +39,8 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 5 + logger = logging.getLogger(__name__) @@ -53,7 +54,7 @@ def __init__(self, charm, peer_relation, substrate): self.charm = charm self.substrate = substrate self.peer_relation = peer_relation - self.certs = TLSCertificatesRequiresV3(self.charm, Config.TLS.TLS_PEER_RELATION) + self.certs = TLSCertificatesRequiresV1(self.charm, Config.TLS.TLS_PEER_RELATION) self.framework.observe( self.charm.on.set_tls_private_key_action, self._on_set_tls_private_key ) @@ -69,7 +70,7 @@ def __init__(self, charm, peer_relation, substrate): self.framework.observe(self.certs.on.certificate_expiring, self._on_certificate_expiring) def is_tls_enabled(self, scope: Scopes): - """Returns a boolean indicating if TLS for a given `scope` is enabled.""" + """Getting internal TLS flag (meaning).""" return self.charm.get_secret(scope, Config.TLS.SECRET_CERT_LABEL) is not None def _on_set_tls_private_key(self, event: ActionEvent) -> None: @@ -182,17 +183,18 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: self.charm.set_secret(scope, Config.TLS.SECRET_CERT_LABEL, event.certificate) self.charm.set_secret(scope, Config.TLS.SECRET_CA_LABEL, event.ca) - self.charm.unit.status = MaintenanceStatus("enabling TLS") if self._waiting_for_certs(): logger.debug( - "Return till both internal and external TLS certificates available to avoid second restart." + "Defer till both internal and external TLS certificates available to avoid second restart." ) + event.defer() return logger.info("Restarting mongod with TLS enabled.") self.charm.delete_tls_certificate_from_workload() self.charm.push_tls_certificate_to_workload() + self.charm.unit.status = MaintenanceStatus("enabling TLS") self.charm.restart_mongod_service() self.charm.unit.status = ActiveStatus() diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v1/tls_certificates.py similarity index 53% rename from lib/charms/tls_certificates_interface/v3/tls_certificates.py rename to lib/charms/tls_certificates_interface/v1/tls_certificates.py index 6fa263973..be171d8e9 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v1/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2024 Canonical Ltd. +# Copyright 2021 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,19 +7,16 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. -Pre-requisites: - - Juju >= 3.0 - ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v1.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography >= 42.0.0 +- cryptography Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -39,10 +36,10 @@ Example: ```python -from charms.tls_certificates_interface.v3.tls_certificates import ( +from charms.tls_certificates_interface.v1.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV3, + TLSCertificatesProvidesV1, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -62,14 +59,12 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV3(self, "certificates") + self.certificates = TLSCertificatesProvidesV1(self, "certificates") self.framework.observe( - self.certificates.on.certificate_request, - self._on_certificate_request + self.certificates.on.certificate_request, self._on_certificate_request ) self.framework.observe( - self.certificates.on.certificate_revocation_request, - self._on_certificate_revocation_request + self.certificates.on.certificate_revoked, self._on_certificate_revocation_request ) self.framework.observe(self.on.install, self._on_install) @@ -111,7 +106,6 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, - recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -130,18 +124,17 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v3.tls_certificates import ( +from charms.tls_certificates_interface.v1.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV3, + TLSCertificatesRequiresV1, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationCreatedEvent +from ops.charm import CharmBase, RelationJoinedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus -from typing import Union class ExampleRequirerCharm(CharmBase): @@ -149,10 +142,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV3(self, "certificates") + self.certificates = TLSCertificatesRequiresV1(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_created, self._on_certificates_relation_created + self.on.certificates_relation_joined, self._on_certificates_relation_joined ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -161,11 +154,7 @@ def __init__(self, *args): self.certificates.on.certificate_expiring, self._on_certificate_expiring ) self.framework.observe( - self.certificates.on.certificate_invalidated, self._on_certificate_invalidated - ) - self.framework.observe( - self.certificates.on.all_certificates_invalidated, - self._on_all_certificates_invalidated + self.certificates.on.certificate_revoked, self._on_certificate_revoked ) def _on_install(self, event) -> None: @@ -180,7 +169,7 @@ def _on_install(self, event) -> None: {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: + def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -207,9 +196,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: replicas_relation.data[self.app].update({"chain": event.chain}) self.unit.status = ActiveStatus() - def _on_certificate_expiring( - self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] - ) -> None: + def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -229,7 +216,12 @@ def _on_certificate_expiring( ) replicas_relation.data[self.app].update({"csr": new_csr.decode()}) - def _certificate_revoked(self) -> None: + def _on_certificate_revoked(self, event: CertificateRevokedEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return old_csr = replicas_relation.data[self.app].get("csr") private_key_password = replicas_relation.data[self.app].get("private_key_password") private_key = replicas_relation.data[self.app].get("private_key") @@ -248,82 +240,44 @@ def _certificate_revoked(self) -> None: replicas_relation.data[self.app].pop("chain") self.unit.status = WaitingStatus("Waiting for new certificate") - def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - if event.reason == "revoked": - self._certificate_revoked() - if event.reason == "expired": - self._on_certificate_expiring(event) - - def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: - # Do what you want with this information, probably remove all certificates. - pass - if __name__ == "__main__": main(ExampleRequirerCharm) ``` - -You can relate both charms by running: - -```bash -juju relate -``` - """ # noqa: D405, D410, D411, D214, D416 import copy import json import logging import uuid -from contextlib import suppress -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from ipaddress import IPv4Address -from typing import List, Literal, Optional, Union +from typing import Dict, List, Optional from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from jsonschema import exceptions, validate -from ops.charm import ( - CharmBase, - CharmEvents, - RelationBrokenEvent, - RelationChangedEvent, - SecretExpiredEvent, -) +from cryptography.hazmat.primitives.serialization import pkcs12 +from cryptography.x509.extensions import Extension, ExtensionNotFound +from jsonschema import exceptions, validate # type: ignore[import] +from ops.charm import CharmBase, CharmEvents, RelationChangedEvent, UpdateStatusEvent from ops.framework import EventBase, EventSource, Handle, Object -from ops.jujuversion import JujuVersion -from ops.model import ( - Application, - ModelError, - Relation, - RelationDataContent, - SecretNotFoundError, - Unit, -) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 3 +LIBAPI = 1 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 13 +LIBPATCH = 12 -PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", + "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v1/schemas/requirer.json", # noqa: E501 "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -344,10 +298,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven "type": "array", "items": { "type": "object", - "properties": { - "certificate_signing_request": {"type": "string"}, - "ca": {"type": "boolean"}, - }, + "properties": {"certificate_signing_request": {"type": "string"}}, "required": ["certificate_signing_request"], }, } @@ -358,7 +309,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", + "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v1/schemas/provider.json", # noqa: E501 "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -432,58 +383,6 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) -@dataclass -class RequirerCSR: - """This class represents a certificate signing request from an interface Requirer.""" - - relation_id: int - application_name: str - unit_name: str - csr: str - is_ca: bool - - -@dataclass -class ProviderCertificate: - """This class represents a certificate from an interface Provider.""" - - relation_id: int - application_name: str - csr: str - certificate: str - ca: str - chain: List[str] - revoked: bool - expiry_time: datetime - expiry_notification_time: Optional[datetime] = None - - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join(reversed(self.chain)) - - def to_json(self) -> str: - """Return the object as a JSON string. - - Returns: - str: JSON representation of the object - """ - return json.dumps( - { - "relation_id": self.relation_id, - "application_name": self.application_name, - "csr": self.csr, - "certificate": self.certificate, - "ca": self.ca, - "chain": self.chain, - "revoked": self.revoked, - "expiry_time": self.expiry_time.isoformat(), - "expiry_notification_time": self.expiry_notification_time.isoformat() - if self.expiry_notification_time - else None, - } - ) - - class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -502,7 +401,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Return snapshot.""" + """Returns snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -511,16 +410,12 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restore snapshot.""" + """Restores snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join(reversed(self.chain)) - class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" @@ -531,7 +426,7 @@ def __init__(self, handle, certificate: str, expiry: str): Args: handle (Handle): Juju framework handle certificate (str): TLS Certificate - expiry (str): Datetime string representing the time at which the certificate + expiry (str): Datetime string reprensenting the time at which the certificate won't be valid anymore. """ super().__init__(handle) @@ -539,96 +434,88 @@ def __init__(self, handle, certificate: str, expiry: str): self.expiry = expiry def snapshot(self) -> dict: - """Return snapshot.""" + """Returns snapshot.""" return {"certificate": self.certificate, "expiry": self.expiry} def restore(self, snapshot: dict): - """Restore snapshot.""" + """Restores snapshot.""" self.certificate = snapshot["certificate"] self.expiry = snapshot["expiry"] -class CertificateInvalidatedEvent(EventBase): - """Charm Event triggered when a TLS certificate is invalidated.""" +class CertificateExpiredEvent(EventBase): + """Charm Event triggered when a TLS certificate is expired.""" + + def __init__(self, handle: Handle, certificate: str): + super().__init__(handle) + self.certificate = certificate + + def snapshot(self) -> dict: + """Returns snapshot.""" + return {"certificate": self.certificate} + + def restore(self, snapshot: dict): + """Restores snapshot.""" + self.certificate = snapshot["certificate"] + + +class CertificateRevokedEvent(EventBase): + """Charm Event triggered when a TLS certificate is revoked.""" def __init__( self, handle: Handle, - reason: Literal["expired", "revoked"], certificate: str, certificate_signing_request: str, ca: str, chain: List[str], + revoked: bool, ): super().__init__(handle) - self.reason = reason - self.certificate_signing_request = certificate_signing_request self.certificate = certificate + self.certificate_signing_request = certificate_signing_request self.ca = ca self.chain = chain + self.revoked = revoked def snapshot(self) -> dict: - """Return snapshot.""" + """Returns snapshot.""" return { - "reason": self.reason, - "certificate_signing_request": self.certificate_signing_request, "certificate": self.certificate, + "certificate_signing_request": self.certificate_signing_request, "ca": self.ca, "chain": self.chain, + "revoked": self.revoked, } def restore(self, snapshot: dict): - """Restore snapshot.""" - self.reason = snapshot["reason"] - self.certificate_signing_request = snapshot["certificate_signing_request"] + """Restores snapshot.""" self.certificate = snapshot["certificate"] + self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] - - -class AllCertificatesInvalidatedEvent(EventBase): - """Charm Event triggered when all TLS certificates are invalidated.""" - - def __init__(self, handle: Handle): - super().__init__(handle) - - def snapshot(self) -> dict: - """Return snapshot.""" - return {} - - def restore(self, snapshot: dict): - """Restore snapshot.""" - pass + self.revoked = snapshot["revoked"] class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__( - self, - handle: Handle, - certificate_signing_request: str, - relation_id: int, - is_ca: bool = False, - ): + def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id - self.is_ca = is_ca def snapshot(self) -> dict: - """Return snapshot.""" + """Returns snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, - "is_ca": self.is_ca, } def restore(self, snapshot: dict): - """Restore snapshot.""" + """Restores snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] - self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -649,7 +536,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Return snapshot.""" + """Returns snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -658,100 +545,33 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restore snapshot.""" + """Restores snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] -def _load_relation_data(relation_data_content: RelationDataContent) -> dict: - """Load relation data from the relation data bag. +def _load_relation_data(raw_relation_data: dict) -> dict: + """Loads relation data from the relation data bag. Json loads all data. Args: - relation_data_content: Relation data from the databag + raw_relation_data: Relation data from the databag Returns: dict: Relation data in dict format. """ - certificate_data = {} - try: - for key in relation_data_content: - try: - certificate_data[key] = json.loads(relation_data_content[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = relation_data_content[key] - except ModelError: - pass + certificate_data = dict() + for key in raw_relation_data: + try: + certificate_data[key] = json.loads(raw_relation_data[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = raw_relation_data[key] return certificate_data -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time - if datetime.now(timezone.utc) < expiry_notification_time - else expiry_time - ) - - -def calculate_expiry_notification_time( - validity_start_time: datetime, - expiry_time: datetime, - provider_recommended_notification_time: Optional[int], - requirer_recommended_notification_time: Optional[int], -) -> datetime: - """Calculate a reasonable time to notify the user about the certificate expiry. - - It takes into account the time recommended by the provider and by the requirer. - Time recommended by the provider is preferred, - then time recommended by the requirer, - then dynamically calculated time. - - Args: - validity_start_time: Certificate validity time - expiry_time: Certificate expiry time - provider_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the provider. - requirer_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the requirer. - - Returns: - datetime: Time to notify the user about the certificate expiry. - """ - if provider_recommended_notification_time is not None: - provider_recommended_notification_time = abs(provider_recommended_notification_time) - provider_recommendation_time_delta = ( - expiry_time - timedelta(hours=provider_recommended_notification_time) - ) - if validity_start_time < provider_recommendation_time_delta: - return provider_recommendation_time_delta - - if requirer_recommended_notification_time is not None: - requirer_recommended_notification_time = abs(requirer_recommended_notification_time) - requirer_recommendation_time_delta = ( - expiry_time - timedelta(hours=requirer_recommended_notification_time) - ) - if validity_start_time < requirer_recommendation_time_delta: - return requirer_recommendation_time_delta - calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) - return expiry_time - timedelta(hours=calculated_hours) - - def generate_ca( private_key: bytes, subject: str, @@ -759,11 +579,11 @@ def generate_ca( validity: int = 365, country: str = "US", ) -> bytes: - """Generate a CA Certificate. + """Generates a CA Certificate. Args: private_key (bytes): Private key - subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + subject (str): Certificate subject private_key_password (bytes): Private key password validity (int): Certificate validity time (in days) country (str): Certificate Issuing country @@ -774,7 +594,7 @@ def generate_ca( private_key_object = serialization.load_pem_private_key( private_key, password=private_key_password ) - subject_name = x509.Name( + subject = issuer = x509.Name( [ x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), @@ -784,25 +604,14 @@ def generate_ca( private_key_object.public_key() # type: ignore[arg-type] ) subject_identifier = key_identifier = subject_identifier_object.public_bytes() - key_usage = x509.KeyUsage( - digital_signature=True, - key_encipherment=True, - key_cert_sign=True, - key_agreement=False, - content_commitment=False, - data_encipherment=False, - crl_sign=False, - encipher_only=False, - decipher_only=False, - ) cert = ( x509.CertificateBuilder() - .subject_name(subject_name) - .issuer_name(subject_name) + .subject_name(subject) + .issuer_name(issuer) .public_key(private_key_object.public_key()) # type: ignore[arg-type] .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=validity)) .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) .add_extension( x509.AuthorityKeyIdentifier( @@ -812,7 +621,6 @@ def generate_ca( ), critical=False, ) - .add_extension(key_usage, critical=True) .add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, @@ -822,105 +630,6 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) -def get_certificate_extensions( - authority_key_identifier: bytes, - csr: x509.CertificateSigningRequest, - alt_names: Optional[List[str]], - is_ca: bool, -) -> List[x509.Extension]: - """Generate a list of certificate extensions from a CSR and other known information. - - Args: - authority_key_identifier (bytes): Authority key identifier - csr (x509.CertificateSigningRequest): CSR - alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - List[x509.Extension]: List of extensions - """ - cert_extensions_list: List[x509.Extension] = [ - x509.Extension( - oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, - value=x509.AuthorityKeyIdentifier( - key_identifier=authority_key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, - value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.BASIC_CONSTRAINTS, - critical=True, - value=x509.BasicConstraints(ca=is_ca, path_length=None), - ), - ] - - sans: List[x509.GeneralName] = [] - san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] - sans.extend(san_alt_names) - try: - loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) - sans.extend( - [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] - ) - sans.extend( - [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] - ) - sans.extend( - [ - x509.RegisteredID(oid) - for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) - ] - ) - except x509.ExtensionNotFound: - pass - - if sans: - cert_extensions_list.append( - x509.Extension( - oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - critical=False, - value=x509.SubjectAlternativeName(sans), - ) - ) - - if is_ca: - cert_extensions_list.append( - x509.Extension( - ExtensionOID.KEY_USAGE, - critical=True, - value=x509.KeyUsage( - digital_signature=False, - content_commitment=False, - key_encipherment=False, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), - ) - ) - - existing_oids = {ext.oid for ext in cert_extensions_list} - for extension in csr.extensions: - if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: - continue - if extension.oid in existing_oids: - logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) - continue - cert_extensions_list.append(extension) - - return cert_extensions_list - - def generate_certificate( csr: bytes, ca: bytes, @@ -928,9 +637,8 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, - is_ca: bool = False, ) -> bytes: - """Generate a TLS certificate based on a CSR. + """Generates a TLS certificate based on a CSR. Args: csr (bytes): CSR @@ -939,15 +647,13 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR - is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate """ csr_object = x509.load_pem_x509_csr(csr) subject = csr_object.subject - ca_pem = x509.load_pem_x509_certificate(ca) - issuer = ca_pem.issuer + issuer = x509.load_pem_x509_certificate(ca).issuer private_key = serialization.load_pem_private_key(ca_key, password=ca_key_password) certificate_builder = ( @@ -956,36 +662,81 @@ def generate_certificate( .issuer_name(issuer) .public_key(csr_object.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=validity)) ) - extensions = get_certificate_extensions( - authority_key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - csr=csr_object, - alt_names=alt_names, - is_ca=is_ca, - ) - for extension in extensions: + + extensions_list = csr_object.extensions + san_ext: Optional[x509.Extension] = None + if alt_names: + full_sans_dns = alt_names.copy() try: - certificate_builder = certificate_builder.add_extension( - extval=extension.value, - critical=extension.critical, + loaded_san_ext = csr_object.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) + full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) + except ExtensionNotFound: + pass + finally: + san_ext = Extension( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + False, + x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), ) - except ValueError as e: - logger.warning("Failed to add extension %s: %s", extension.oid, e) + if not extensions_list: + extensions_list = x509.Extensions([san_ext]) + for extension in extensions_list: + if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: + extension = san_ext + + certificate_builder = certificate_builder.add_extension( + extension.value, + critical=extension.critical, + ) + certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) +def generate_pfx_package( + certificate: bytes, + private_key: bytes, + package_password: str, + private_key_password: Optional[bytes] = None, +) -> bytes: + """Generates a PFX package to contain the TLS certificate and private key. + + Args: + certificate (bytes): TLS certificate + private_key (bytes): Private key + package_password (str): Password to open the PFX package + private_key_password (bytes): Private key password + + Returns: + bytes: + """ + private_key_object = serialization.load_pem_private_key( + private_key, password=private_key_password + ) + certificate_object = x509.load_pem_x509_certificate(certificate) + name = certificate_object.subject.rfc4514_string() + pfx_bytes = pkcs12.serialize_key_and_certificates( + name=name.encode(), + cert=certificate_object, + key=private_key_object, # type: ignore[arg-type] + cas=None, + encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), + ) + return pfx_bytes + + def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, public_exponent: int = 65537, ) -> bytes: - """Generate a private key. + """Generates a private key. Args: password (bytes): Password for decrypting the private key @@ -1002,24 +753,20 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=( - serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption() - ), + encryption_algorithm=serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption(), ) return key_bytes -def generate_csr( # noqa: C901 +def generate_csr( private_key: bytes, subject: str, add_unique_id_to_subject_name: bool = True, organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -1027,26 +774,24 @@ def generate_csr( # noqa: C901 sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, ) -> bytes: - """Generate a CSR using private key and subject. + """Generates a CSR using private key and subject. Args: private_key (bytes): Private key - subject (str): CSR Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + subject (str): CSR Subject. add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's subject name. Always leave to "True" when the CSR is used to request certificates using the tls-certificates relation. organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. - state_or_province_name (str): State or Province Name. - locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) sans_oid (list): List of registered ID SANs sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) sans_ip (list): List of IP subject alternative names - additional_critical_extensions (list): List of critical additional extension objects. + additional_critical_extensions (list): List if critical additional extension objects. Object must be a x509 ExtensionType. Returns: @@ -1065,12 +810,6 @@ def generate_csr( # noqa: C901 subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) - if state_or_province_name: - subject_name.append( - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) - ) - if locality_name: - subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] @@ -1093,59 +832,6 @@ def generate_csr( # noqa: C901 return signed_certificate.public_bytes(serialization.Encoding.PEM) -def csr_matches_certificate(csr: str, cert: str) -> bool: - """Check if a CSR matches a certificate. - - Args: - csr (str): Certificate Signing Request as a string - cert (str): Certificate as a string - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - -def _relation_data_is_valid( - relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict -) -> bool: - """Check whether relation data is valid based on json schema. - - Args: - relation (Relation): Relation object - app_or_unit (Union[Application, Unit]): Application or unit object - json_schema (dict): Json schema - - Returns: - bool: Whether relation data is valid. - """ - relation_data = _load_relation_data(relation.data[app_or_unit]) - try: - validate(instance=relation_data, schema=json_schema) - return True - except exceptions.ValidationError: - return False - - class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1158,14 +844,14 @@ class CertificatesRequirerCharmEvents(CharmEvents): certificate_available = EventSource(CertificateAvailableEvent) certificate_expiring = EventSource(CertificateExpiringEvent) - certificate_invalidated = EventSource(CertificateInvalidatedEvent) - all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) + certificate_expired = EventSource(CertificateExpiredEvent) + certificate_revoked = EventSource(CertificateRevokedEvent) -class TLSCertificatesProvidesV3(Object): +class TLSCertificatesProvidesV1(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" - on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] + on = CertificatesProviderCharmEvents() def __init__(self, charm: CharmBase, relationship_name: str): super().__init__(charm, relationship_name) @@ -1175,22 +861,6 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.charm = charm self.relationship_name = relationship_name - def _load_app_relation_data(self, relation: Relation) -> dict: - """Load relation data from the application relation data bag. - - Json loads all data. - - Args: - relation: Relation data from the application databag - - Returns: - dict: Relation data in dict format. - """ - # If unit is not leader, it does not try to reach relation data. - if not self.model.unit.is_leader(): - return {} - return _load_relation_data(relation.data[self.charm.app]) - def _add_certificate( self, relation_id: int, @@ -1198,9 +868,8 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], - recommended_expiry_notification_time: Optional[int] = None, ) -> None: - """Add certificate to relation data. + """Adds certificate to relation data. Args: relation_id (int): Relation id @@ -1208,8 +877,6 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain - recommended_expiry_notification_time (int): - Time in hours before the certificate expires to notify the user. Returns: None @@ -1227,9 +894,8 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, - "recommended_expiry_notification_time": recommended_expiry_notification_time, } - provider_relation_data = self._load_app_relation_data(relation) + provider_relation_data = _load_relation_data(relation.data[self.charm.app]) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) if new_certificate in certificates: @@ -1244,7 +910,7 @@ def _remove_certificate( certificate: Optional[str] = None, certificate_signing_request: Optional[str] = None, ) -> None: - """Remove certificate from a given relation based on user provided certificate or csr. + """Removes certificate from a given relation based on user provided certificate or csr. Args: relation_id (int): Relation id @@ -1262,7 +928,7 @@ def _remove_certificate( raise RuntimeError( f"Relation {self.relationship_name} with relation id {relation_id} does not exist" ) - provider_relation_data = self._load_app_relation_data(relation) + provider_relation_data = _load_relation_data(relation.data[self.charm.app]) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) for certificate_dict in certificates: @@ -1275,13 +941,29 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) + @staticmethod + def _relation_data_is_valid(certificates_data: dict) -> bool: + """Uses JSON schema validator to validate relation data content. + + Args: + certificates_data (dict): Certificate data dictionary as retrieved from relation data. + + Returns: + bool: True/False depending on whether the relation data follows the json schema. + """ + try: + validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) + return True + except exceptions.ValidationError: + return False + def revoke_all_certificates(self) -> None: - """Revoke all certificates of this provider. + """Revokes all certificates of this provider. This method is meant to be used when the Root CA has changed. """ for relation in self.model.relations[self.relationship_name]: - provider_relation_data = self._load_app_relation_data(relation) + provider_relation_data = _load_relation_data(relation.data[self.charm.app]) provider_certificates = copy.deepcopy(provider_relation_data.get("certificates", [])) for certificate in provider_certificates: certificate["revoked"] = True @@ -1294,9 +976,8 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, - recommended_expiry_notification_time: Optional[int] = None, ) -> None: - """Add certificates to relation data. + """Adds certificates to relation data. Args: certificate (str): Certificate @@ -1304,14 +985,10 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID - recommended_expiry_notification_time (int): - Recommended time in hours before the certificate expires to notify the user. Returns: None """ - if not self.model.unit.is_leader(): - return certificates_relation = self.model.get_relation( relation_name=self.relationship_name, relation_id=relation_id ) @@ -1327,11 +1004,10 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], - recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: - """Remove a given certificate from relation data. + """Removes a given certificate from relation data. Args: certificate (str): TLS Certificate @@ -1345,67 +1021,8 @@ def remove_certificate(self, certificate: str) -> None: for certificate_relation in certificates_relation: self._remove_certificate(certificate=certificate, relation_id=certificate_relation.id) - def get_issued_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued (non revoked) certificates. - - Returns: - List: List of ProviderCertificate objects - """ - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - return [certificate for certificate in provider_certificates if not certificate.revoked] - - def get_provider_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued certificates. - - Returns: - List: List of ProviderCertificate objects - """ - certificates: List[ProviderCertificate] = [] - relations = ( - [ - relation - for relation in self.model.relations[self.relationship_name] - if relation.id == relation_id - ] - if relation_id is not None - else self.model.relations.get(self.relationship_name, []) - ) - for relation in relations: - if not relation.app: - logger.warning("Relation %s does not have an application", relation.id) - continue - provider_relation_data = self._load_app_relation_data(relation) - provider_certificates = provider_relation_data.get("certificates", []) - for certificate in provider_certificates: - try: - certificate_object = x509.load_pem_x509_certificate( - data=certificate["certificate"].encode() - ) - except ValueError as e: - logger.error("Could not load certificate - Skipping: %s", e) - continue - provider_certificate = ProviderCertificate( - relation_id=relation.id, - application_name=relation.app.name, - csr=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], - revoked=certificate.get("revoked", False), - expiry_time=certificate_object.not_valid_after_utc, - expiry_notification_time=certificate.get( - "recommended_expiry_notification_time" - ), - ) - certificates.append(provider_certificate) - return certificates - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handle relation changed event. + """Handler triggerred on relation changed event. Looks at the relation data and either emits: - certificate request event: If the unit relation data contains a CSR for which @@ -1419,258 +1036,120 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ - if event.unit is None: - logger.error("Relation_changed event does not have a unit.") - return - if not self.model.unit.is_leader(): - return - if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): - logger.debug("Relation data did not pass JSON Schema validation") + assert event.unit is not None + requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) + provider_relation_data = _load_relation_data(event.relation.data[self.charm.app]) + if not self._relation_data_is_valid(requirer_relation_data): + logger.warning( + f"Relation data did not pass JSON Schema validation: {requirer_relation_data}" + ) return - provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) - requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) + provider_certificates = provider_relation_data.get("certificates", []) + requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) provider_csrs = [ - certificate_creation_request.csr + certificate_creation_request["certificate_signing_request"] for certificate_creation_request in provider_certificates ] - for certificate_request in requirer_csrs: - if certificate_request.csr not in provider_csrs: + requirer_unit_csrs = [ + certificate_creation_request["certificate_signing_request"] + for certificate_creation_request in requirer_csrs + ] + for certificate_signing_request in requirer_unit_csrs: + if certificate_signing_request not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request.csr, - relation_id=certificate_request.relation_id, - is_ca=certificate_request.is_ca, + certificate_signing_request=certificate_signing_request, + relation_id=event.relation.id, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revoke certificates for which no unit has a CSR. - - Goes through all generated certificates and compare against the list of CSRs for all units. - - Returns: - None - """ - provider_certificates = self.get_provider_certificates(relation_id) - requirer_csrs = self.get_requirer_csrs(relation_id) - list_of_csrs = [csr.csr for csr in requirer_csrs] - for certificate in provider_certificates: - if certificate.csr not in list_of_csrs: - self.on.certificate_revocation_request.emit( - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, - ) - self.remove_certificate(certificate=certificate.certificate) + """Revokes certificates for which no unit has a CSR. - def get_outstanding_certificate_requests( - self, relation_id: Optional[int] = None - ) -> List[RequirerCSR]: - """Return CSR's for which no certificate has been issued. + Goes through all generated certificates and compare agains the list of CSRS for all units + of a given relationship. Args: relation_id (int): Relation id Returns: - list: List of RequirerCSR objects. - """ - requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) - outstanding_csrs: List[RequirerCSR] = [] - for relation_csr in requirer_csrs: - if not self.certificate_issued_for_csr( - app_name=relation_csr.application_name, - csr=relation_csr.csr, - relation_id=relation_id, - ): - outstanding_csrs.append(relation_csr) - return outstanding_csrs - - def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: - """Return a list of requirers' CSRs. - - It returns CSRs from all relations if relation_id is not specified. - CSRs are returned per relation id, application name and unit name. - - Returns: - list: List[RequirerCSR] + None """ - relation_csrs: List[RequirerCSR] = [] - relations = ( - [ - relation - for relation in self.model.relations[self.relationship_name] - if relation.id == relation_id - ] - if relation_id is not None - else self.model.relations.get(self.relationship_name, []) + certificates_relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=relation_id ) - - for relation in relations: - for unit in relation.units: - requirer_relation_data = _load_relation_data(relation.data[unit]) - unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - for unit_csr in unit_csrs_list: - csr = unit_csr.get("certificate_signing_request") - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - ca = unit_csr.get("ca", False) - if not relation.app: - logger.warning("No remote app in relation - Skipping") - continue - relation_csr = RequirerCSR( - relation_id=relation.id, - application_name=relation.app.name, - unit_name=unit.name, - csr=csr, - is_ca=ca, - ) - relation_csrs.append(relation_csr) - return relation_csrs - - def certificate_issued_for_csr( - self, app_name: str, csr: str, relation_id: Optional[int] - ) -> bool: - """Check whether a certificate has been issued for a given CSR. - - Args: - app_name (str): Application name that the CSR belongs to. - csr (str): Certificate Signing Request. - relation_id (Optional[int]): Relation ID - - Returns: - bool: True/False depending on whether a certificate has been issued for the given CSR. - """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) - for issued_certificate in issued_certificates_per_csr: - if issued_certificate.csr == csr and issued_certificate.application_name == app_name: - return csr_matches_certificate(csr, issued_certificate.certificate) - return False + if not certificates_relation: + raise RuntimeError(f"Relation {self.relationship_name} does not exist") + provider_relation_data = _load_relation_data(certificates_relation.data[self.charm.app]) + list_of_csrs: List[str] = [] + for unit in certificates_relation.units: + requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) + requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) + provider_certificates = provider_relation_data.get("certificates", []) + for certificate in provider_certificates: + if certificate["certificate_signing_request"] not in list_of_csrs: + self.on.certificate_revocation_request.emit( + certificate=certificate["certificate"], + certificate_signing_request=certificate["certificate_signing_request"], + ca=certificate["ca"], + chain=certificate["chain"], + ) + self.remove_certificate(certificate=certificate["certificate"]) -class TLSCertificatesRequiresV3(Object): +class TLSCertificatesRequiresV1(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] + on = CertificatesRequirerCharmEvents() def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: Optional[int] = None, + expiry_notification_time: int = 168, ): - """Generate/use private key and observes relation changed event. + """Generates/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Number of hours prior to certificate expiry. - Used to trigger the CertificateExpiring event. - This value is used as a recommendation only, - The actual value is calculated taking into account the provider's recommendation. + expiry_notification_time (int): Time difference between now and expiry (in hours). + Used to trigger the CertificateExpiring event. Default: 7 days. """ super().__init__(charm, relationship_name) - if not JujuVersion.from_environ().has_secrets: - logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time self.framework.observe( charm.on[relationship_name].relation_changed, self._on_relation_changed ) - self.framework.observe( - charm.on[relationship_name].relation_broken, self._on_relation_broken - ) - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + self.framework.observe(charm.on.update_status, self._on_update_status) - def get_requirer_csrs(self) -> List[RequirerCSR]: - """Return list of requirer's CSRs from relation unit data. - - Returns: - list: List of RequirerCSR objects. - """ + @property + def _requirer_csrs(self) -> List[Dict[str, str]]: + """Returns list of requirer CSR's from relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: - return [] - requirer_csrs = [] + raise RuntimeError(f"Relation {self.relationship_name} does not exist") requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) - for requirer_csr_dict in requirer_csrs_dict: - csr = requirer_csr_dict.get("certificate_signing_request") - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - ca = requirer_csr_dict.get("ca", False) - relation_csr = RequirerCSR( - relation_id=relation.id, - application_name=self.model.app.name, - unit_name=self.model.unit.name, - csr=csr, - is_ca=ca, - ) - requirer_csrs.append(relation_csr) - return requirer_csrs + return requirer_relation_data.get("certificate_signing_requests", []) - def get_provider_certificates(self) -> List[ProviderCertificate]: - """Return list of certificates from the provider's relation data.""" - provider_certificates: List[ProviderCertificate] = [] + @property + def _provider_certificates(self) -> List[Dict[str, str]]: + """Returns list of provider CSR's from relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: - logger.debug("No relation: %s", self.relationship_name) - return [] + raise RuntimeError(f"Relation {self.relationship_name} does not exist") if not relation.app: - logger.debug("No remote app in relation: %s", self.relationship_name) - return [] + raise RuntimeError(f"Remote app for relation {self.relationship_name} does not exist") provider_relation_data = _load_relation_data(relation.data[relation.app]) - provider_certificate_dicts = provider_relation_data.get("certificates", []) - for provider_certificate_dict in provider_certificate_dicts: - certificate = provider_certificate_dict.get("certificate") - if not certificate: - logger.warning("No certificate found in relation data - Skipping") - continue - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError as e: - logger.error("Could not load certificate - Skipping: %s", e) - continue - ca = provider_certificate_dict.get("ca") - chain = provider_certificate_dict.get("chain", []) - csr = provider_certificate_dict.get("certificate_signing_request") - recommended_expiry_notification_time = provider_certificate_dict.get( - "recommended_expiry_notification_time" - ) - expiry_time = certificate_object.not_valid_after_utc - validity_start_time = certificate_object.not_valid_before_utc - expiry_notification_time = calculate_expiry_notification_time( - validity_start_time=validity_start_time, - expiry_time=expiry_time, - provider_recommended_notification_time=recommended_expiry_notification_time, - requirer_recommended_notification_time=self.expiry_notification_time, - ) - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - revoked = provider_certificate_dict.get("revoked", False) - provider_certificate = ProviderCertificate( - relation_id=relation.id, - application_name=relation.app.name, - csr=csr, - certificate=certificate, - ca=ca, - chain=chain, - revoked=revoked, - expiry_time=expiry_time, - expiry_notification_time=expiry_notification_time, - ) - provider_certificates.append(provider_certificate) - return provider_certificates + return provider_relation_data.get("certificates", []) - def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: - """Add CSR to relation data. + def _add_requirer_csr(self, csr: str) -> None: + """Adds CSR to relation data. Args: csr (str): Certificate Signing Request - is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1681,24 +1160,16 @@ def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - for requirer_csr in self.get_requirer_csrs(): - if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: - logger.info("CSR already in relation data - Doing nothing") - return - new_csr_dict = { - "certificate_signing_request": csr, - "ca": is_ca, - } - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) - new_relation_data = copy.deepcopy(existing_relation_data) - new_relation_data.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( - new_relation_data - ) + new_csr_dict = {"certificate_signing_request": csr} + if new_csr_dict in self._requirer_csrs: + logger.info("CSR already in relation data - Doing nothing") + return + requirer_csrs = copy.deepcopy(self._requirer_csrs) + requirer_csrs.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: - """Remove CSR from relation data. + def _remove_requirer_csr(self, csr: str) -> None: + """Removes CSR from relation data. Args: csr (str): Certificate signing request @@ -1712,44 +1183,36 @@ def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - if not self.get_requirer_csrs(): - logger.info("No CSRs in relation data - Doing nothing") + requirer_csrs = copy.deepcopy(self._requirer_csrs) + csr_dict = {"certificate_signing_request": csr} + if csr_dict not in requirer_csrs: + logger.info("CSR not in relation data - Doing nothing") return - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) - new_relation_data = copy.deepcopy(existing_relation_data) - for requirer_csr in new_relation_data: - if requirer_csr["certificate_signing_request"] == csr: - new_relation_data.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( - new_relation_data - ) + requirer_csrs.remove(csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def request_certificate_creation( - self, certificate_signing_request: bytes, is_ca: bool = False - ) -> None: + def request_certificate_creation(self, certificate_signing_request: bytes) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request - is_ca (bool): Whether the certificate is a CA certificate Returns: None """ relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError( + message = ( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr_to_relation_data( - certificate_signing_request.decode().strip(), is_ca=is_ca - ) + logger.error(message) + raise RuntimeError(message) + self._add_requirer_csr(certificate_signing_request.decode().strip()) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Remove CSR from relation data. + """Removes CSR from relation data. The provider of this relation is then expected to remove certificates associated to this CSR from the relation data as well and emit a request_certificate_revocation event for the @@ -1761,13 +1224,13 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) + self._remove_requirer_csr(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes ) -> None: - """Renew certificate. + """Renews certificate. Removes old CSR from relation data and adds new one. @@ -1789,69 +1252,24 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - def get_assigned_certificates(self) -> List[ProviderCertificate]: - """Get a list of certificates that were assigned to this unit. - - Returns: - List: List[ProviderCertificate] - """ - assigned_certificates = [] - for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): - if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - assigned_certificates.append(cert) - return assigned_certificates - - def get_expiring_certificates(self) -> List[ProviderCertificate]: - """Get a list of certificates that were assigned to this unit that are expiring or expired. - - Returns: - List: List[ProviderCertificate] - """ - expiring_certificates: List[ProviderCertificate] = [] - for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): - if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - if not cert.expiry_time or not cert.expiry_notification_time: - continue - if datetime.now(timezone.utc) > cert.expiry_notification_time: - expiring_certificates.append(cert) - return expiring_certificates - - def get_certificate_signing_requests( - self, - fulfilled_only: bool = False, - unfulfilled_only: bool = False, - ) -> List[RequirerCSR]: - """Get the list of CSR's that were sent to the provider. - - You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. + @staticmethod + def _relation_data_is_valid(certificates_data: dict) -> bool: + """Checks whether relation data is valid based on json schema. Args: - fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. - unfulfilled_only (bool): This option will discard CSRs that have certificates signed. + certificates_data: Certificate data in dict format. Returns: - List of RequirerCSR objects. + bool: Whether relation data is valid. """ - csrs = [] - for requirer_csr in self.get_requirer_csrs(): - cert = self._find_certificate_in_relation_data(requirer_csr.csr) - if (unfulfilled_only and cert) or (fulfilled_only and not cert): - continue - csrs.append(requirer_csr) - - return csrs + try: + validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) + return True + except exceptions.ValidationError: + return False def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handle relation changed event. - - Goes through all providers certificates that match a requested CSR. - - If the provider certificate is revoked, emit a CertificateInvalidateEvent, - otherwise emit a CertificateAvailableEvent. - - Remove the secret for revoked certificate, or add a secret with the correct expiry - time for new certificates. + """Handler triggered on relation changed events. Args: event: Juju event @@ -1859,141 +1277,84 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ - if not event.app: - logger.warning("No remote app in relation - Skipping") + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.warning(f"No relation: {self.relationship_name}") + return + if not relation.app: + logger.warning(f"No remote app in relation: {self.relationship_name}") return - if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): - logger.debug("Relation data did not pass JSON Schema validation") + provider_relation_data = _load_relation_data(relation.data[relation.app]) + if not self._relation_data_is_valid(provider_relation_data): + logger.warning( + f"Provider relation data did not pass JSON Schema validation: " + f"{event.relation.data[relation.app]}" + ) return - provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request.csr - for certificate_creation_request in self.get_requirer_csrs() + certificate_creation_request["certificate_signing_request"] + for certificate_creation_request in self._requirer_csrs ] - for certificate in provider_certificates: - if certificate.csr in requirer_csrs: - if certificate.revoked: - with suppress(SecretNotFoundError): - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") - secret.remove_all_revisions() - self.on.certificate_invalidated.emit( - reason="revoked", - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, + for certificate in self._provider_certificates: + if certificate["certificate_signing_request"] in requirer_csrs: + if certificate.get("revoked", False): + self.on.certificate_revoked.emit( + certificate_signing_request=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=True, ) else: - try: - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") - secret.set_content({"certificate": certificate.certificate}) - secret.set_info( - expire=self._get_next_secret_expiry_time(certificate), - ) - except SecretNotFoundError: - logger.debug("Adding secret with label %s", f"{LIBID}-{certificate.csr}") - secret = self.charm.unit.add_secret( - {"certificate": certificate.certificate}, - label=f"{LIBID}-{certificate.csr}", - expire=self._get_next_secret_expiry_time(certificate), - ) self.on.certificate_available.emit( - certificate_signing_request=certificate.csr, - certificate=certificate.certificate, - ca=certificate.ca, - chain=certificate.chain, + certificate_signing_request=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], ) - def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: - """Return the expiry time or expiry notification time. - - Extracts the expiry time from the provided certificate, calculates the - expiry notification time and return the closest of the two, that is in - the future. - - Args: - certificate: ProviderCertificate object - - Returns: - Optional[datetime]: None if the certificate expiry time cannot be read, - next expiry time otherwise. - """ - if not certificate.expiry_time or not certificate.expiry_notification_time: - return None - return _get_closest_future_time( - certificate.expiry_notification_time, - certificate.expiry_time, - ) + def _on_update_status(self, event: UpdateStatusEvent) -> None: + """Triggered on update status event. - def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handle Relation Broken Event. - - Emitting `all_certificates_invalidated` from `relation-broken` rather - than `relation-departed` since certs are stored in app data. + Goes through each certificate in the "certificates" relation and checks their expiry date. + If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if + they are expired, emits a CertificateExpiredEvent. Args: - event: Juju event + event (UpdateStatusEvent): Juju event Returns: None """ - self.on.all_certificates_invalidated.emit() - - def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle Secret Expired Event. - - Loads the certificate from the secret, and will emit 1 of 2 - events. - - If the certificate is not yet expired, emits CertificateExpiringEvent - and updates the expiry time of the secret to the exact expiry time on - the certificate. - - If the certificate is expired, emits CertificateInvalidedEvent and - deletes the secret. - - Args: - event (SecretExpiredEvent): Juju event - """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): - return - csr = event.secret.label[len(f"{LIBID}-") :] - provider_certificate = self._find_certificate_in_relation_data(csr) - if not provider_certificate: - # A secret expired but we did not find matching certificate. Cleaning up - event.secret.remove_all_revisions() + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.warning(f"No relation: {self.relationship_name}") return - - if not provider_certificate.expiry_time: - # A secret expired but matching certificate is invalid. Cleaning up - event.secret.remove_all_revisions() + if not relation.app: + logger.warning(f"No remote app in relation: {self.relationship_name}") return - - if datetime.now(timezone.utc) < provider_certificate.expiry_time: - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=provider_certificate.certificate, - expiry=provider_certificate.expiry_time.isoformat(), - ) - event.secret.set_info( - expire=provider_certificate.expiry_time, - ) - else: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=provider_certificate.certificate, - certificate_signing_request=provider_certificate.csr, - ca=provider_certificate.ca, - chain=provider_certificate.chain, + provider_relation_data = _load_relation_data(relation.data[relation.app]) + if not self._relation_data_is_valid(provider_relation_data): + logger.warning( + f"Provider relation data did not pass JSON Schema validation: " + f"{relation.data[relation.app]}" ) - self.request_certificate_revocation(provider_certificate.certificate.encode()) - event.secret.remove_all_revisions() - - def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: - """Return the certificate that match the given CSR.""" - for provider_certificate in self.get_provider_certificates(): - if provider_certificate.csr != csr: + return + for certificate_dict in self._provider_certificates: + certificate = certificate_dict["certificate"] + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError: + logger.warning("Could not load certificate.") continue - return provider_certificate - return None + time_difference = certificate_object.not_valid_after - datetime.utcnow() + if time_difference.total_seconds() < 0: + logger.warning("Certificate is expired") + self.on.certificate_expired.emit(certificate=certificate) + self.request_certificate_revocation(certificate.encode()) + continue + if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): + logger.warning("Certificate almost expired") + self.on.certificate_expiring.emit( + certificate=certificate, expiry=certificate_object.not_valid_after.isoformat() + )