diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml index 7fa64e2c..763d1932 100644 --- a/.github/workflows/integration_test.yaml +++ b/.github/workflows/integration_test.yaml @@ -8,7 +8,7 @@ jobs: uses: canonical/operator-workflows/.github/workflows/integration_test.yaml@main secrets: inherit with: - extra-arguments: --localstack-address 172.17.0.1 -m "not (requires_secrets)" + extra-arguments: --localstack-address 172.17.0.1 pre-run-script: localstack-installation.sh trivy-image-config: "trivy.yaml" juju-channel: 3.1/stable diff --git a/.github/workflows/integration_test_with_secrets.yaml b/.github/workflows/integration_test_with_secrets.yaml deleted file mode 100644 index fa745931..00000000 --- a/.github/workflows/integration_test_with_secrets.yaml +++ /dev/null @@ -1,15 +0,0 @@ -name: Integration tests (require secrets) - -on: - pull_request: - -jobs: - integration-tests-with-secrets: - uses: canonical/operator-workflows/.github/workflows/integration_test.yaml@main - secrets: inherit - with: - extra-arguments: --localstack-address 172.17.0.1 -m "requires_secrets" - pre-run-script: localstack-installation.sh - trivy-image-config: "trivy.yaml" - juju-channel: 3.1/stable - channel: 1.28-strict/stable diff --git a/config.yaml b/config.yaml index a23dd440..8cf75602 100644 --- a/config.yaml +++ b/config.yaml @@ -25,10 +25,6 @@ options: type: string description: "Comma-separated list of groups to sync from SAML provider." default: "" - saml_target_url: - type: string - description: "SAML authentication target url." - default: "" smtp_address: type: string description: "Hostname / IP that should be used to send SMTP mail." diff --git a/docs/how-to/configure-saml.md b/docs/how-to/configure-saml.md index 93a5155f..aebd760a 100644 --- a/docs/how-to/configure-saml.md +++ b/docs/how-to/configure-saml.md @@ -2,7 +2,15 @@ To configure Discourse's SAML integration you'll have to set the following configuration options with the appropriate values for your SAML server by running `juju config [charm_name] [configuration]=[value]`. -The SAML URL needs to be scpecified in `saml_target_url`. If you wish to force the login to go through SAML, enable `force_saml_login`. +If you wish to force the login to go through SAML, enable `force_saml_login`. The groups to be synced from the provider can be defined in `saml_sync_groups` as a comma-separated list of values. +In order to implement the relation discourse has to be related with the [saml-integrator](https://charmhub.io/saml-integrator): +``` +juju deploy saml-integrator --channel=edge +# Set the SAML integrator configs +juju config saml-integrator metadata_url=https://login.staging.ubuntu.com/saml/metadata +juju config saml-integrator entity_id=https://login.staging.ubuntu.com +juju integrate discourse-k8s saml-integrator +``` For more details on the configuration options and their default values see the [configuration reference](https://charmhub.io/discourse-k8s/configure). \ No newline at end of file diff --git a/lib/charms/saml_integrator/v0/saml.py b/lib/charms/saml_integrator/v0/saml.py new file mode 100644 index 00000000..f0cd3726 --- /dev/null +++ b/lib/charms/saml_integrator/v0/saml.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Canonical Ltd. +# Licensed under the Apache2.0. See LICENSE file in charm source for details. + +"""Library to manage the relation data for the SAML Integrator charm. + +This library contains the Requires and Provides classes for handling the relation +between an application and a charm providing the `saml`relation. +It also contains a `SamlRelationData` class to wrap the SAML data that will +be shared via the relation. + +### Requirer Charm + +```python + +from charms.saml_integrator.v0 import SamlDataAvailableEvent, SamlRequires + +class SamlRequirerCharm(ops.CharmBase): + def __init__(self, *args): + super().__init__(*args) + self.saml = saml.SamlRequires(self) + self.framework.observe(self.saml.on.saml_data_available, self._handler) + ... + + def _handler(self, events: SamlDataAvailableEvent) -> None: + ... + +``` + +As shown above, the library provides a custom event to handle the scenario in +which new SAML data has been added or updated. + +### Provider Charm + +Following the previous example, this is an example of the provider charm. + +```python +from charms.saml_integrator.v0 import SamlDataAvailableEvent, SamlRequires + +class SamlRequirerCharm(ops.CharmBase): + def __init__(self, *args): + super().__init__(*args) + self.saml = SamlRequires(self) + self.framework.observe(self.saml.on.saml_data_available, self._on_saml_data_available) + ... + + def _on_saml_data_available(self, events: SamlDataAvailableEvent) -> None: + ... + + def __init__(self, *args): + super().__init__(*args) + self.saml = SamlProvides(self) + +``` +The SamlProvides object wraps the list of relations into a `relations` property +and provides an `update_relation_data` method to update the relation data by passing +a `SamlRelationData` data object. +""" + +# The unique Charmhub library identifier, never change it +LIBID = "511cdfa7de3d43568bf9b512f9c9f89d" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 5 + +# pylint: disable=wrong-import-position +import re +import typing + +import ops +from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic.tools import parse_obj_as + +DEFAULT_RELATION_NAME = "saml" + + +class SamlEndpoint(BaseModel): + """Represent a SAML endpoint. + + Attrs: + name: Endpoint name. + url: Endpoint URL. + binding: Endpoint binding. + response_url: URL to address the response to. + """ + + name: str = Field(..., min_length=1) + url: AnyHttpUrl + binding: str = Field(..., min_length=1) + response_url: typing.Optional[AnyHttpUrl] + + def to_relation_data(self) -> typing.Dict[str, str]: + """Convert an instance of SamlEndpoint to the relation representation. + + Returns: + Dict containing the representation. + """ + result: typing.Dict[str, str] = {} + # Get the HTTP method from the SAML binding + http_method = self.binding.split(":")[-1].split("-")[-1].lower() + # Transform name into snakecase + lowercase_name = re.sub(r"(? "SamlEndpoint": + """Initialize a new instance of the SamlEndpoint class from the relation data. + + Args: + relation_data: the relation data. + + Returns: + A SamlEndpoint instance. + """ + url_key = "" + for key in relation_data: + # A key per method and entpoint type that is always present + if key.endswith("_redirect_url") or key.endswith("_post_url"): + url_key = key + # Get endpoint name from the relation data key + lowercase_name = "_".join(url_key.split("_")[:-2]) + name = "".join(x.capitalize() for x in lowercase_name.split("_")) + # Get HTTP method from the relation data key + http_method = url_key.split("_")[-2] + prefix = f"{lowercase_name}_{http_method}_" + return cls( + name=name, + url=parse_obj_as(AnyHttpUrl, relation_data[f"{prefix}url"]), + binding=relation_data[f"{prefix}binding"], + response_url=( + parse_obj_as(AnyHttpUrl, relation_data[f"{prefix}response_url"]) + if f"{prefix}response_url" in relation_data + else None + ), + ) + + +class SamlRelationData(BaseModel): + """Represent the relation data. + + Attrs: + entity_id: SAML entity ID. + metadata_url: URL to the metadata. + certificates: List of SAML certificates. + endpoints: List of SAML endpoints. + """ + + entity_id: str = Field(..., min_length=1) + metadata_url: AnyHttpUrl + certificates: typing.List[str] + endpoints: typing.List[SamlEndpoint] + + def to_relation_data(self) -> typing.Dict[str, str]: + """Convert an instance of SamlDataAvailableEvent to the relation representation. + + Returns: + Dict containing the representation. + """ + result = { + "entity_id": self.entity_id, + "metadata_url": str(self.metadata_url), + "x509certs": ",".join(self.certificates), + } + for endpoint in self.endpoints: + result.update(endpoint.to_relation_data()) + return result + + +class SamlDataAvailableEvent(ops.RelationEvent): + """Saml event emitted when relation data has changed. + + Attrs: + entity_id: SAML entity ID. + metadata_url: URL to the metadata. + certificates: Tuple containing the SAML certificates. + endpoints: Tuple containing the SAML endpoints. + """ + + @property + def entity_id(self) -> str: + """Fetch the SAML entity ID from the relation.""" + assert self.relation.app + return self.relation.data[self.relation.app].get("entity_id") + + @property + def metadata_url(self) -> str: + """Fetch the SAML metadata URL from the relation.""" + assert self.relation.app + return parse_obj_as(AnyHttpUrl, self.relation.data[self.relation.app].get("metadata_url")) + + @property + def certificates(self) -> typing.Tuple[str, ...]: + """Fetch the SAML certificates from the relation.""" + assert self.relation.app + return tuple(self.relation.data[self.relation.app].get("x509certs").split(",")) + + @property + def endpoints(self) -> typing.Tuple[SamlEndpoint, ...]: + """Fetch the SAML endpoints from the relation.""" + assert self.relation.app + relation_data = self.relation.data[self.relation.app] + endpoints = [ + SamlEndpoint.from_relation_data( + { + key2: relation_data.get(key2) + for key2 in relation_data + if key2.startswith("_".join(key.split("_")[:-1])) + } + ) + for key in relation_data + if key.endswith("_redirect_url") or key.endswith("_post_url") + ] + endpoints.sort(key=lambda ep: ep.name) + return tuple(endpoints) + + +class SamlRequiresEvents(ops.CharmEvents): + """SAML events. + + This class defines the events that a SAML requirer can emit. + + Attrs: + saml_data_available: the SamlDataAvailableEvent. + """ + + saml_data_available = ops.EventSource(SamlDataAvailableEvent) + + +class SamlRequires(ops.Object): + """Requirer side of the SAML relation. + + Attrs: + on: events the provider can emit. + """ + + on = SamlRequiresEvents() + + def __init__(self, charm: ops.CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None: + """Construct. + + Args: + charm: the provider charm. + relation_name: the relation name. + """ + super().__init__(charm, relation_name) + self.charm = charm + self.relation_name = relation_name + self.framework.observe(charm.on[relation_name].relation_changed, self._on_relation_changed) + + def _on_relation_changed(self, event: ops.RelationChangedEvent) -> None: + """Event emitted when the relation has changed. + + Args: + event: event triggering this handler. + """ + assert event.relation.app + if event.relation.data[event.relation.app]: + self.on.saml_data_available.emit(event.relation, app=event.app, unit=event.unit) + + +class SamlProvides(ops.Object): + """Provider side of the SAML relation. + + Attrs: + relations: list of charm relations. + """ + + def __init__(self, charm: ops.CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None: + """Construct. + + Args: + charm: the provider charm. + relation_name: the relation name. + """ + super().__init__(charm, relation_name) + self.charm = charm + self.relation_name = relation_name + + @property + def relations(self) -> typing.List[ops.Relation]: + """The list of Relation instances associated with this relation_name. + + Returns: + List of relations to this charm. + """ + return list(self.model.relations[self.relation_name]) + + def update_relation_data(self, relation: ops.Relation, saml_data: SamlRelationData) -> None: + """Update the relation data. + + Args: + relation: the relation for which to update the data. + saml_data: a SamlRelationData instance wrapping the data to be updated. + """ + relation.data[self.charm.model.app].update(saml_data.to_relation_data()) diff --git a/metadata.yaml b/metadata.yaml index b3029c4c..09a952d4 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -51,6 +51,10 @@ requires: limit: 1 logging: interface: loki_push_api + saml: + interface: saml + limit: 1 + optional: true assumes: - k8s-api diff --git a/pyproject.toml b/pyproject.toml index 627382bb..1f11d215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,9 +18,6 @@ show_missing = true [tool.pytest.ini_options] minversion = "6.0" log_cli_level = "INFO" -markers = [ - "requires_secrets: mark tests that require external secrets" -] # Formatting tools configuration [tool.black] diff --git a/requirements.txt b/requirements.txt index ca8415da..a48ab584 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ ops-lib-pgsql +pydantic==1.10.14 diff --git a/src/charm.py b/src/charm.py index 128f91eb..24658d77 100755 --- a/src/charm.py +++ b/src/charm.py @@ -3,6 +3,8 @@ # See LICENSE file for licensing details. """Charm for Discourse on kubernetes.""" +import base64 +import hashlib import logging import os.path import typing @@ -19,6 +21,11 @@ from charms.prometheus_k8s.v0.prometheus_scrape import MetricsEndpointProvider from charms.redis_k8s.v0.redis import RedisRelationCharmEvents, RedisRequires from charms.rolling_ops.v0.rollingops import RollingOpsManager +from charms.saml_integrator.v0.saml import ( + DEFAULT_RELATION_NAME, + SamlDataAvailableEvent, + SamlRequires, +) from ops.charm import ActionEvent, CharmBase, HookEvent, RelationBrokenEvent from ops.framework import StoredState from ops.main import main @@ -103,6 +110,8 @@ def __init__(self, *args): redis_relation={}, ) self._require_nginx_route() + self.saml = SamlRequires(self) + self.framework.observe(self.saml.on.saml_data_available, self._on_saml_data_available) self.framework.observe(self.on.start, self._on_start) self.framework.observe(self.on.upgrade_charm, self._on_upgrade_charm) @@ -193,6 +202,22 @@ def _on_config_changed(self, _: HookEvent) -> None: """ self._configure_pod() + def _on_saml_data_available(self, event: SamlDataAvailableEvent) -> None: + """Handle SAML data available.""" + if self.unit.is_leader(): + # Utilizing the SHA1 hash is safe in this case, so a nosec ignore will be put in place. + fingerprint = hashlib.sha1( + base64.b64decode(event.certificates[0]) + ).hexdigest() # nosec + relation = self.model.get_relation(DEFAULT_RELATION_NAME) + # Will ignore union-attr since asserting the relation type will make bandit complain. + relation.data[self.app].update( # type: ignore[union-attr] + { + "fingerprint": fingerprint, + } + ) + self._on_config_changed(event) + def _on_rolling_restart(self, _: ops.EventBase) -> None: """Handle rolling restart event. @@ -258,14 +283,23 @@ def _is_config_valid(self) -> bool: if self.config["throttle_level"] not in THROTTLE_LEVELS: errors.append(f"throttle_level must be one of: {' '.join(THROTTLE_LEVELS.keys())}") - if self.config["force_saml_login"] and not self.config["saml_target_url"]: - errors.append("force_saml_login can not be true without a saml_target_url") + if ( + self.config["force_saml_login"] + and self.model.get_relation(DEFAULT_RELATION_NAME) is None + ): + errors.append("force_saml_login cannot be true without a saml relation") - if self.config["saml_sync_groups"] and not self.config["saml_target_url"]: - errors.append("'saml_sync_groups' cannot be specified without a 'saml_target_url'") + if ( + self.config["saml_sync_groups"] + and self.model.get_relation(DEFAULT_RELATION_NAME) is None + ): + errors.append("'saml_sync_groups' cannot be specified without a saml relation") - if self.config["saml_target_url"] and not self.config["force_https"]: - errors.append("'saml_target_url' cannot be specified without 'force_https' being true") + if ( + self.model.get_relation(DEFAULT_RELATION_NAME) is not None + and not self.config["force_https"] + ): + errors.append("A saml relation cannot be specified without 'force_https' being true") if self.config.get("s3_enabled"): errors.extend( @@ -284,22 +318,22 @@ def _get_saml_config(self) -> typing.Dict[str, typing.Any]: Returns: Dictionary with the SAML configuration settings.. """ - ubuntu_one_fingerprint = "32:15:20:9F:A4:3C:8E:3E:8E:47:72:62:9A:86:8D:0E:E6:CF:45:D5" - ubuntu_one_staging_fingerprint = ( - "D2:B4:86:49:1B:AC:29:F6:A4:C8:CF:0D:3A:8F:AD:86:36:0A:77:C0" - ) - saml_fingerprints = { - "https://login.ubuntu.com/+saml": ubuntu_one_fingerprint, - "https://login.staging.ubuntu.com/+saml": ubuntu_one_staging_fingerprint, - } saml_config = {} - if self.config.get("saml_target_url"): - saml_config["DISCOURSE_SAML_TARGET_URL"] = self.config["saml_target_url"] + relation = self.model.get_relation(DEFAULT_RELATION_NAME) + if ( + relation is not None + and relation.data[self.app] + and relation.app + and relation.data[relation.app] + ): + saml_config["DISCOURSE_SAML_TARGET_URL"] = relation.data[relation.app][ + "single_sign_on_service_redirect_url" + ] saml_config["DISCOURSE_SAML_FULL_SCREEN_LOGIN"] = ( "true" if self.config["force_saml_login"] else "false" ) - fingerprint = saml_fingerprints.get(self.config["saml_target_url"]) + fingerprint = relation.data[self.app].get("fingerprint") if fingerprint: saml_config["DISCOURSE_SAML_CERT_FINGERPRINT"] = fingerprint diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f8dcfb5b..9c477190 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -149,8 +149,7 @@ async def app_fixture( redis_app = await model.deploy("redis-k8s", series="jammy", channel="latest/edge") await model.wait_for_idle(apps=[redis_app.name], status="active") - nii_app = await model.deploy("nginx-ingress-integrator", series="focal", trust=True) - await model.wait_for_idle(apps=[nii_app.name], status="waiting") + await model.deploy("nginx-ingress-integrator", series="focal", trust=True) resources = { "discourse-image": pytestconfig.getoption("--discourse-image"), @@ -204,13 +203,10 @@ async def setup_saml_config(app: Application, model: Model): discourse_app = model.applications[app.name] original_config: dict = await discourse_app.get_config() original_config = {k: v["value"] for k, v in original_config.items()} - await discourse_app.set_config( - {"saml_target_url": "https://login.staging.ubuntu.com/+saml", "force_https": "true"} - ) + await discourse_app.set_config({"force_https": "true"}) yield await discourse_app.set_config( { - "saml_target_url": original_config["saml_target_url"], "force_https": str(original_config["force_https"]).lower(), } ) diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py index 4e2ad65e..b242deab 100644 --- a/tests/integration/test_charm.py +++ b/tests/integration/test_charm.py @@ -16,6 +16,7 @@ from botocore.config import Config from ops.model import ActiveStatus, Application from pytest_operator.plugin import Model +from saml_test_helper import SamlK8sTestHelper # pylint: disable=import-error from charm import PROMETHEUS_PORT @@ -179,32 +180,80 @@ def generate_s3_config(localstack_address: str) -> Dict: } +@pytest.mark.asyncio +async def test_create_category( + discourse_address: str, + admin_credentials: types.Credentials, + admin_api_key: str, +): + """ + arrange: Given discourse application and an admin user + act: if an admin user creates a category + assert: a category should be created normally. + """ + category_info = {"name": "test", "color": "FFFFFF"} + res = requests.post( + f"{discourse_address}/categories.json", + headers={ + "Api-Key": admin_api_key, + "Api-Username": admin_credentials.username, + }, + json=category_info, + timeout=60, + ) + category_id = res.json()["category"]["id"] + category = requests.get(f"{discourse_address}/c/{category_id}/show.json", timeout=60).json()[ + "category" + ] + + assert category["name"] == category_info["name"] + assert category["color"] == category_info["color"] + + @pytest.mark.asyncio @pytest.mark.abort_on_fail -@pytest.mark.requires_secrets @pytest.mark.usefixtures("setup_saml_config") async def test_saml_login( # pylint: disable=too-many-locals,too-many-arguments app: Application, requests_timeout: int, run_action, model: Model, - saml_email: str, - saml_password: str, ): """ arrange: after discourse charm has been deployed, with all required relation established. act: add an admin user and enable force-https mode. assert: user can login discourse using SAML Authentication. """ - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - action_result = await run_action( - app.name, "add-admin-user", email=saml_email, password=saml_password + saml_helper = SamlK8sTestHelper.deploy_saml_idp(model.name) + saml_app: Application = await model.deploy( + "saml-integrator", + channel="latest/edge", + series="jammy", + trust=True, ) + await model.wait_for_idle() + saml_helper.prepare_pod(model.name, f"{saml_app.name}-0") + saml_helper.prepare_pod(model.name, f"{app.name}-0") + await model.wait_for_idle() + await saml_app.set_config( # type: ignore[attr-defined] + { + "entity_id": saml_helper.entity_id, + "metadata_url": saml_helper.metadata_url, + } + ) + await model.add_relation(app.name, "saml-integrator") + await model.wait_for_idle() + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + # discourse need a long password and a valid email + # username can't be "discourse" or it will be renamed + username = "ubuntu" + email = "ubuntu@canonical.com" + password = "test-discourse-k8s-password" # nosec + saml_helper.register_user(username=username, email=email, password=password) + + action_result = await run_action(app.name, "add-admin-user", email=email, password=password) assert "user" in action_result - await model.wait_for_idle(status="active") - - username = saml_email.split("@")[0] host = app.name original_getaddrinfo = socket.getaddrinfo @@ -215,6 +264,14 @@ def patched_getaddrinfo(*args): with unittest.mock.patch.multiple(socket, getaddrinfo=patched_getaddrinfo): session = requests.session() + + response = session.get( + f"https://{host}/auth/saml/metadata", + verify=False, + timeout=10, + ) + saml_helper.register_service_provider(name=host, metadata=response.text) + preference_page = session.get( f"https://{host}/u/{username}/preferences/account", verify=False, @@ -229,51 +286,23 @@ def patched_getaddrinfo(*args): timeout=requests_timeout, ) csrf_token = response.json()["csrf"] - login_page = session.post( + redirect_response = session.post( f"https://{host}/auth/saml", data={"authenticity_token": csrf_token}, timeout=requests_timeout, + allow_redirects=False, ) - csrf_tokens = re.findall( - "", login_page.text - ) - assert len(csrf_tokens), login_page.text - csrf_token = csrf_tokens[0] - saml_callback = session.post( - "https://login.staging.ubuntu.com/+login", - data={ - "csrfmiddlewaretoken": csrf_token, - "email": saml_email, - "user-intentions": "login", - "password": saml_password, - "next": "/saml/process", - "continue": "", - "openid.usernamesecret": "", - }, - headers={"Referer": login_page.url}, - timeout=requests_timeout, - ) - saml_responses = re.findall( - '', saml_callback.text - ) - assert len(saml_responses), saml_callback.text - saml_response = saml_responses[0] - session.post( - f"https://{host}/auth/saml/callback", - data={ - "RelayState": "None", - "SAMLResponse": saml_response, - "openid.usernamesecret": "", - }, - verify=False, - timeout=requests_timeout, + assert redirect_response.status_code == 302 + redirect_url = redirect_response.headers["Location"] + saml_response = saml_helper.redirect_sso_login( + redirect_url, username=username, password=password ) + assert f"https://{host}" in saml_response.url session.post( - f"https://{host}/auth/saml/callback", - data={"SAMLResponse": saml_response, "SameSite": "1"}, - verify=False, - timeout=requests_timeout, + saml_response.url, + data={"SAMLResponse": saml_response.data["SAMLResponse"], "SameSite": "1"}, ) + session.post(saml_response.url, data=saml_response.data) preference_page = session.get( f"https://{host}/u/{username}/preferences/account", @@ -283,36 +312,6 @@ def patched_getaddrinfo(*args): assert preference_page.status_code == 200 -@pytest.mark.asyncio -async def test_create_category( - discourse_address: str, - admin_credentials: types.Credentials, - admin_api_key: str, -): - """ - arrange: Given discourse application and an admin user - act: if an admin user creates a category - assert: a category should be created normally. - """ - category_info = {"name": "test", "color": "FFFFFF"} - res = requests.post( - f"{discourse_address}/categories.json", - headers={ - "Api-Key": admin_api_key, - "Api-Username": admin_credentials.username, - }, - json=category_info, - timeout=60, - ) - category_id = res.json()["category"]["id"] - category = requests.get(f"{discourse_address}/c/{category_id}/show.json", timeout=60).json()[ - "category" - ] - - assert category["name"] == category_info["name"] - assert category["color"] == category_info["color"] - - @pytest.mark.asyncio async def test_serve_compiled_assets( discourse_address: str, diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 3ac35c93..6cbc79fb 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -17,6 +17,7 @@ def start_harness( *, + saml_fields: tuple = (False, "", ""), with_postgres: bool = True, with_redis: bool = True, with_ingress: bool = False, @@ -51,6 +52,9 @@ def start_harness( if with_ingress: _add_ingress_relation(harness) + if saml_fields[0]: + _add_saml_relation(harness, saml_fields[1], saml_fields[2]) + if with_config is not None: harness.update_config(with_config) @@ -130,3 +134,33 @@ def _add_ingress_relation(harness): """ nginx_route_relation_id = harness.add_relation("nginx-route", "ingress") harness.add_relation_unit(nginx_route_relation_id, "ingress/0") + + +def _add_saml_relation(harness, saml_target, fingerprint): + """Add ingress relation and relation data to the charm. + + Args: + - A harness instance + + Returns: the same harness instance with an added relation + """ + harness.set_leader(True) + saml_relation_id = harness.add_relation("saml", "saml-integrator") + harness.add_relation_unit(saml_relation_id, "saml-integrator/0") + harness.disable_hooks() + harness.update_relation_data( + relation_id=saml_relation_id, + app_or_unit="saml-integrator", + key_values={ + "single_sign_on_service_redirect_url": f"{saml_target}/+saml", + }, + ) + harness.enable_hooks() + harness.update_relation_data( + relation_id=saml_relation_id, + app_or_unit=harness.charm.app.name, + key_values={ + "entity_id": saml_target, + "fingerprint": fingerprint, + }, + ) diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index b65ee0d8..f3e2c48a 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -71,9 +71,9 @@ def test_on_config_changed_when_no_saml_target(): act: when force_saml_login configuration is True and there's no saml_target_url assert: it will get to blocked status waiting for the latter. """ - harness = helpers.start_harness(with_config={"force_saml_login": True, "saml_target_url": ""}) + harness = helpers.start_harness(with_config={"force_saml_login": True}) assert harness.model.unit.status == BlockedStatus( - "force_saml_login can not be true without a saml_target_url" + "force_saml_login cannot be true without a saml relation" ) @@ -83,11 +83,9 @@ def test_on_config_changed_when_saml_sync_groups_and_no_url_invalid(): act: when saml_sync_groups configuration is provided and there's no saml_target_url assert: it will get to blocked status waiting for the latter. """ - harness = helpers.start_harness( - with_config={"saml_sync_groups": "group1", "saml_target_url": ""} - ) + harness = helpers.start_harness(with_config={"saml_sync_groups": "group1"}) assert harness.model.unit.status == BlockedStatus( - "'saml_sync_groups' cannot be specified without a 'saml_target_url'" + "'saml_sync_groups' cannot be specified without a saml relation" ) @@ -97,11 +95,10 @@ def test_on_config_changed_when_saml_target_url_and_force_https_disabled(): act: when saml_target_url configuration is provided and force_https is False assert: it will get to blocked status waiting for the latter. """ - harness = helpers.start_harness( - with_config={"saml_target_url": "group1", "force_https": False} - ) + harness = helpers.start_harness(with_config={"force_https": False}, saml_fields=(True, "", "")) + harness.charm._is_config_valid() assert harness.model.unit.status == BlockedStatus( - "'saml_target_url' cannot be specified without 'force_https' being true" + "A saml relation cannot be specified without 'force_https' being true" ) @@ -229,13 +226,14 @@ def test_on_config_changed_when_valid_no_fingerprint(): harness = helpers.start_harness( with_config={ "force_saml_login": True, - "saml_target_url": "https://login.sample.com/+saml", "saml_sync_groups": "group1", "s3_enabled": False, "force_https": True, - } + }, + saml_fields=(True, "https://login.sample.com", ""), ) - harness.container_pebble_ready("discourse") + + harness.container_pebble_ready(SERVICE_NAME) updated_plan = harness.get_container_pebble_plan(SERVICE_NAME).to_dict() updated_plan_env = updated_plan["services"][SERVICE_NAME]["environment"] @@ -274,7 +272,6 @@ def test_on_config_changed_when_valid(): "enable_cors": True, "external_hostname": "discourse.local", "force_saml_login": True, - "saml_target_url": "https://login.ubuntu.com/+saml", "saml_sync_groups": "group1", "smtp_address": "smtp.internal", "smtp_domain": "foo.internal", @@ -289,9 +286,10 @@ def test_on_config_changed_when_valid(): "s3_region": "the-infinite-and-beyond", "s3_secret_access_key": "s|kI0ure_k3Y", "force_https": True, - } + }, + saml_fields=(True, "https://login.ubuntu.com", "fingerprint"), ) - harness.container_pebble_ready("discourse") + harness.container_pebble_ready(SERVICE_NAME) updated_plan = harness.get_container_pebble_plan(SERVICE_NAME).to_dict() updated_plan_env = updated_plan["services"][SERVICE_NAME]["environment"] diff --git a/tox.ini b/tox.ini index 81e0c5c2..8509f327 100644 --- a/tox.ini +++ b/tox.ini @@ -79,7 +79,7 @@ commands = [testenv:unit] description = Run unit tests deps = - cosl + cosl==0.0.8 pytest coverage[toml] ops>=2.6.0 @@ -114,10 +114,12 @@ deps = cosl juju>=3.0 protobuf==3.20.3 + pydantic==1.10.14 pytest pytest-operator pytest-asyncio psycopg2-binary + git+https://github.com/canonical/saml-test-idp.git -r{toxinidir}/requirements.txt commands = pytest -v --tb native --ignore={[vars]tst_path}unit --log-cli-level=INFO -s {posargs}