diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml
index 7fa64e2c..763d1932 100644
--- a/.github/workflows/integration_test.yaml
+++ b/.github/workflows/integration_test.yaml
@@ -8,7 +8,7 @@ jobs:
uses: canonical/operator-workflows/.github/workflows/integration_test.yaml@main
secrets: inherit
with:
- extra-arguments: --localstack-address 172.17.0.1 -m "not (requires_secrets)"
+ extra-arguments: --localstack-address 172.17.0.1
pre-run-script: localstack-installation.sh
trivy-image-config: "trivy.yaml"
juju-channel: 3.1/stable
diff --git a/.github/workflows/integration_test_with_secrets.yaml b/.github/workflows/integration_test_with_secrets.yaml
deleted file mode 100644
index fa745931..00000000
--- a/.github/workflows/integration_test_with_secrets.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-name: Integration tests (require secrets)
-
-on:
- pull_request:
-
-jobs:
- integration-tests-with-secrets:
- uses: canonical/operator-workflows/.github/workflows/integration_test.yaml@main
- secrets: inherit
- with:
- extra-arguments: --localstack-address 172.17.0.1 -m "requires_secrets"
- pre-run-script: localstack-installation.sh
- trivy-image-config: "trivy.yaml"
- juju-channel: 3.1/stable
- channel: 1.28-strict/stable
diff --git a/config.yaml b/config.yaml
index a23dd440..8cf75602 100644
--- a/config.yaml
+++ b/config.yaml
@@ -25,10 +25,6 @@ options:
type: string
description: "Comma-separated list of groups to sync from SAML provider."
default: ""
- saml_target_url:
- type: string
- description: "SAML authentication target url."
- default: ""
smtp_address:
type: string
description: "Hostname / IP that should be used to send SMTP mail."
diff --git a/docs/how-to/configure-saml.md b/docs/how-to/configure-saml.md
index 93a5155f..aebd760a 100644
--- a/docs/how-to/configure-saml.md
+++ b/docs/how-to/configure-saml.md
@@ -2,7 +2,15 @@
To configure Discourse's SAML integration you'll have to set the following configuration options with the appropriate values for your SAML server by running `juju config [charm_name] [configuration]=[value]`.
-The SAML URL needs to be scpecified in `saml_target_url`. If you wish to force the login to go through SAML, enable `force_saml_login`.
+If you wish to force the login to go through SAML, enable `force_saml_login`.
The groups to be synced from the provider can be defined in `saml_sync_groups` as a comma-separated list of values.
+In order to implement the relation discourse has to be related with the [saml-integrator](https://charmhub.io/saml-integrator):
+```
+juju deploy saml-integrator --channel=edge
+# Set the SAML integrator configs
+juju config saml-integrator metadata_url=https://login.staging.ubuntu.com/saml/metadata
+juju config saml-integrator entity_id=https://login.staging.ubuntu.com
+juju integrate discourse-k8s saml-integrator
+```
For more details on the configuration options and their default values see the [configuration reference](https://charmhub.io/discourse-k8s/configure).
\ No newline at end of file
diff --git a/lib/charms/saml_integrator/v0/saml.py b/lib/charms/saml_integrator/v0/saml.py
new file mode 100644
index 00000000..f0cd3726
--- /dev/null
+++ b/lib/charms/saml_integrator/v0/saml.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+
+# Copyright 2024 Canonical Ltd.
+# Licensed under the Apache2.0. See LICENSE file in charm source for details.
+
+"""Library to manage the relation data for the SAML Integrator charm.
+
+This library contains the Requires and Provides classes for handling the relation
+between an application and a charm providing the `saml`relation.
+It also contains a `SamlRelationData` class to wrap the SAML data that will
+be shared via the relation.
+
+### Requirer Charm
+
+```python
+
+from charms.saml_integrator.v0 import SamlDataAvailableEvent, SamlRequires
+
+class SamlRequirerCharm(ops.CharmBase):
+ def __init__(self, *args):
+ super().__init__(*args)
+ self.saml = saml.SamlRequires(self)
+ self.framework.observe(self.saml.on.saml_data_available, self._handler)
+ ...
+
+ def _handler(self, events: SamlDataAvailableEvent) -> None:
+ ...
+
+```
+
+As shown above, the library provides a custom event to handle the scenario in
+which new SAML data has been added or updated.
+
+### Provider Charm
+
+Following the previous example, this is an example of the provider charm.
+
+```python
+from charms.saml_integrator.v0 import SamlDataAvailableEvent, SamlRequires
+
+class SamlRequirerCharm(ops.CharmBase):
+ def __init__(self, *args):
+ super().__init__(*args)
+ self.saml = SamlRequires(self)
+ self.framework.observe(self.saml.on.saml_data_available, self._on_saml_data_available)
+ ...
+
+ def _on_saml_data_available(self, events: SamlDataAvailableEvent) -> None:
+ ...
+
+ def __init__(self, *args):
+ super().__init__(*args)
+ self.saml = SamlProvides(self)
+
+```
+The SamlProvides object wraps the list of relations into a `relations` property
+and provides an `update_relation_data` method to update the relation data by passing
+a `SamlRelationData` data object.
+"""
+
+# The unique Charmhub library identifier, never change it
+LIBID = "511cdfa7de3d43568bf9b512f9c9f89d"
+
+# Increment this major API version when introducing breaking changes
+LIBAPI = 0
+
+# Increment this PATCH version before using `charmcraft publish-lib` or reset
+# to 0 if you are raising the major API version
+LIBPATCH = 5
+
+# pylint: disable=wrong-import-position
+import re
+import typing
+
+import ops
+from pydantic import AnyHttpUrl, BaseModel, Field
+from pydantic.tools import parse_obj_as
+
+DEFAULT_RELATION_NAME = "saml"
+
+
+class SamlEndpoint(BaseModel):
+ """Represent a SAML endpoint.
+
+ Attrs:
+ name: Endpoint name.
+ url: Endpoint URL.
+ binding: Endpoint binding.
+ response_url: URL to address the response to.
+ """
+
+ name: str = Field(..., min_length=1)
+ url: AnyHttpUrl
+ binding: str = Field(..., min_length=1)
+ response_url: typing.Optional[AnyHttpUrl]
+
+ def to_relation_data(self) -> typing.Dict[str, str]:
+ """Convert an instance of SamlEndpoint to the relation representation.
+
+ Returns:
+ Dict containing the representation.
+ """
+ result: typing.Dict[str, str] = {}
+ # Get the HTTP method from the SAML binding
+ http_method = self.binding.split(":")[-1].split("-")[-1].lower()
+ # Transform name into snakecase
+ lowercase_name = re.sub(r"(? "SamlEndpoint":
+ """Initialize a new instance of the SamlEndpoint class from the relation data.
+
+ Args:
+ relation_data: the relation data.
+
+ Returns:
+ A SamlEndpoint instance.
+ """
+ url_key = ""
+ for key in relation_data:
+ # A key per method and entpoint type that is always present
+ if key.endswith("_redirect_url") or key.endswith("_post_url"):
+ url_key = key
+ # Get endpoint name from the relation data key
+ lowercase_name = "_".join(url_key.split("_")[:-2])
+ name = "".join(x.capitalize() for x in lowercase_name.split("_"))
+ # Get HTTP method from the relation data key
+ http_method = url_key.split("_")[-2]
+ prefix = f"{lowercase_name}_{http_method}_"
+ return cls(
+ name=name,
+ url=parse_obj_as(AnyHttpUrl, relation_data[f"{prefix}url"]),
+ binding=relation_data[f"{prefix}binding"],
+ response_url=(
+ parse_obj_as(AnyHttpUrl, relation_data[f"{prefix}response_url"])
+ if f"{prefix}response_url" in relation_data
+ else None
+ ),
+ )
+
+
+class SamlRelationData(BaseModel):
+ """Represent the relation data.
+
+ Attrs:
+ entity_id: SAML entity ID.
+ metadata_url: URL to the metadata.
+ certificates: List of SAML certificates.
+ endpoints: List of SAML endpoints.
+ """
+
+ entity_id: str = Field(..., min_length=1)
+ metadata_url: AnyHttpUrl
+ certificates: typing.List[str]
+ endpoints: typing.List[SamlEndpoint]
+
+ def to_relation_data(self) -> typing.Dict[str, str]:
+ """Convert an instance of SamlDataAvailableEvent to the relation representation.
+
+ Returns:
+ Dict containing the representation.
+ """
+ result = {
+ "entity_id": self.entity_id,
+ "metadata_url": str(self.metadata_url),
+ "x509certs": ",".join(self.certificates),
+ }
+ for endpoint in self.endpoints:
+ result.update(endpoint.to_relation_data())
+ return result
+
+
+class SamlDataAvailableEvent(ops.RelationEvent):
+ """Saml event emitted when relation data has changed.
+
+ Attrs:
+ entity_id: SAML entity ID.
+ metadata_url: URL to the metadata.
+ certificates: Tuple containing the SAML certificates.
+ endpoints: Tuple containing the SAML endpoints.
+ """
+
+ @property
+ def entity_id(self) -> str:
+ """Fetch the SAML entity ID from the relation."""
+ assert self.relation.app
+ return self.relation.data[self.relation.app].get("entity_id")
+
+ @property
+ def metadata_url(self) -> str:
+ """Fetch the SAML metadata URL from the relation."""
+ assert self.relation.app
+ return parse_obj_as(AnyHttpUrl, self.relation.data[self.relation.app].get("metadata_url"))
+
+ @property
+ def certificates(self) -> typing.Tuple[str, ...]:
+ """Fetch the SAML certificates from the relation."""
+ assert self.relation.app
+ return tuple(self.relation.data[self.relation.app].get("x509certs").split(","))
+
+ @property
+ def endpoints(self) -> typing.Tuple[SamlEndpoint, ...]:
+ """Fetch the SAML endpoints from the relation."""
+ assert self.relation.app
+ relation_data = self.relation.data[self.relation.app]
+ endpoints = [
+ SamlEndpoint.from_relation_data(
+ {
+ key2: relation_data.get(key2)
+ for key2 in relation_data
+ if key2.startswith("_".join(key.split("_")[:-1]))
+ }
+ )
+ for key in relation_data
+ if key.endswith("_redirect_url") or key.endswith("_post_url")
+ ]
+ endpoints.sort(key=lambda ep: ep.name)
+ return tuple(endpoints)
+
+
+class SamlRequiresEvents(ops.CharmEvents):
+ """SAML events.
+
+ This class defines the events that a SAML requirer can emit.
+
+ Attrs:
+ saml_data_available: the SamlDataAvailableEvent.
+ """
+
+ saml_data_available = ops.EventSource(SamlDataAvailableEvent)
+
+
+class SamlRequires(ops.Object):
+ """Requirer side of the SAML relation.
+
+ Attrs:
+ on: events the provider can emit.
+ """
+
+ on = SamlRequiresEvents()
+
+ def __init__(self, charm: ops.CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None:
+ """Construct.
+
+ Args:
+ charm: the provider charm.
+ relation_name: the relation name.
+ """
+ super().__init__(charm, relation_name)
+ self.charm = charm
+ self.relation_name = relation_name
+ self.framework.observe(charm.on[relation_name].relation_changed, self._on_relation_changed)
+
+ def _on_relation_changed(self, event: ops.RelationChangedEvent) -> None:
+ """Event emitted when the relation has changed.
+
+ Args:
+ event: event triggering this handler.
+ """
+ assert event.relation.app
+ if event.relation.data[event.relation.app]:
+ self.on.saml_data_available.emit(event.relation, app=event.app, unit=event.unit)
+
+
+class SamlProvides(ops.Object):
+ """Provider side of the SAML relation.
+
+ Attrs:
+ relations: list of charm relations.
+ """
+
+ def __init__(self, charm: ops.CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None:
+ """Construct.
+
+ Args:
+ charm: the provider charm.
+ relation_name: the relation name.
+ """
+ super().__init__(charm, relation_name)
+ self.charm = charm
+ self.relation_name = relation_name
+
+ @property
+ def relations(self) -> typing.List[ops.Relation]:
+ """The list of Relation instances associated with this relation_name.
+
+ Returns:
+ List of relations to this charm.
+ """
+ return list(self.model.relations[self.relation_name])
+
+ def update_relation_data(self, relation: ops.Relation, saml_data: SamlRelationData) -> None:
+ """Update the relation data.
+
+ Args:
+ relation: the relation for which to update the data.
+ saml_data: a SamlRelationData instance wrapping the data to be updated.
+ """
+ relation.data[self.charm.model.app].update(saml_data.to_relation_data())
diff --git a/metadata.yaml b/metadata.yaml
index b3029c4c..09a952d4 100644
--- a/metadata.yaml
+++ b/metadata.yaml
@@ -51,6 +51,10 @@ requires:
limit: 1
logging:
interface: loki_push_api
+ saml:
+ interface: saml
+ limit: 1
+ optional: true
assumes:
- k8s-api
diff --git a/pyproject.toml b/pyproject.toml
index 627382bb..1f11d215 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,9 +18,6 @@ show_missing = true
[tool.pytest.ini_options]
minversion = "6.0"
log_cli_level = "INFO"
-markers = [
- "requires_secrets: mark tests that require external secrets"
-]
# Formatting tools configuration
[tool.black]
diff --git a/requirements.txt b/requirements.txt
index ca8415da..a48ab584 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
ops-lib-pgsql
+pydantic==1.10.14
diff --git a/src/charm.py b/src/charm.py
index 128f91eb..24658d77 100755
--- a/src/charm.py
+++ b/src/charm.py
@@ -3,6 +3,8 @@
# See LICENSE file for licensing details.
"""Charm for Discourse on kubernetes."""
+import base64
+import hashlib
import logging
import os.path
import typing
@@ -19,6 +21,11 @@
from charms.prometheus_k8s.v0.prometheus_scrape import MetricsEndpointProvider
from charms.redis_k8s.v0.redis import RedisRelationCharmEvents, RedisRequires
from charms.rolling_ops.v0.rollingops import RollingOpsManager
+from charms.saml_integrator.v0.saml import (
+ DEFAULT_RELATION_NAME,
+ SamlDataAvailableEvent,
+ SamlRequires,
+)
from ops.charm import ActionEvent, CharmBase, HookEvent, RelationBrokenEvent
from ops.framework import StoredState
from ops.main import main
@@ -103,6 +110,8 @@ def __init__(self, *args):
redis_relation={},
)
self._require_nginx_route()
+ self.saml = SamlRequires(self)
+ self.framework.observe(self.saml.on.saml_data_available, self._on_saml_data_available)
self.framework.observe(self.on.start, self._on_start)
self.framework.observe(self.on.upgrade_charm, self._on_upgrade_charm)
@@ -193,6 +202,22 @@ def _on_config_changed(self, _: HookEvent) -> None:
"""
self._configure_pod()
+ def _on_saml_data_available(self, event: SamlDataAvailableEvent) -> None:
+ """Handle SAML data available."""
+ if self.unit.is_leader():
+ # Utilizing the SHA1 hash is safe in this case, so a nosec ignore will be put in place.
+ fingerprint = hashlib.sha1(
+ base64.b64decode(event.certificates[0])
+ ).hexdigest() # nosec
+ relation = self.model.get_relation(DEFAULT_RELATION_NAME)
+ # Will ignore union-attr since asserting the relation type will make bandit complain.
+ relation.data[self.app].update( # type: ignore[union-attr]
+ {
+ "fingerprint": fingerprint,
+ }
+ )
+ self._on_config_changed(event)
+
def _on_rolling_restart(self, _: ops.EventBase) -> None:
"""Handle rolling restart event.
@@ -258,14 +283,23 @@ def _is_config_valid(self) -> bool:
if self.config["throttle_level"] not in THROTTLE_LEVELS:
errors.append(f"throttle_level must be one of: {' '.join(THROTTLE_LEVELS.keys())}")
- if self.config["force_saml_login"] and not self.config["saml_target_url"]:
- errors.append("force_saml_login can not be true without a saml_target_url")
+ if (
+ self.config["force_saml_login"]
+ and self.model.get_relation(DEFAULT_RELATION_NAME) is None
+ ):
+ errors.append("force_saml_login cannot be true without a saml relation")
- if self.config["saml_sync_groups"] and not self.config["saml_target_url"]:
- errors.append("'saml_sync_groups' cannot be specified without a 'saml_target_url'")
+ if (
+ self.config["saml_sync_groups"]
+ and self.model.get_relation(DEFAULT_RELATION_NAME) is None
+ ):
+ errors.append("'saml_sync_groups' cannot be specified without a saml relation")
- if self.config["saml_target_url"] and not self.config["force_https"]:
- errors.append("'saml_target_url' cannot be specified without 'force_https' being true")
+ if (
+ self.model.get_relation(DEFAULT_RELATION_NAME) is not None
+ and not self.config["force_https"]
+ ):
+ errors.append("A saml relation cannot be specified without 'force_https' being true")
if self.config.get("s3_enabled"):
errors.extend(
@@ -284,22 +318,22 @@ def _get_saml_config(self) -> typing.Dict[str, typing.Any]:
Returns:
Dictionary with the SAML configuration settings..
"""
- ubuntu_one_fingerprint = "32:15:20:9F:A4:3C:8E:3E:8E:47:72:62:9A:86:8D:0E:E6:CF:45:D5"
- ubuntu_one_staging_fingerprint = (
- "D2:B4:86:49:1B:AC:29:F6:A4:C8:CF:0D:3A:8F:AD:86:36:0A:77:C0"
- )
- saml_fingerprints = {
- "https://login.ubuntu.com/+saml": ubuntu_one_fingerprint,
- "https://login.staging.ubuntu.com/+saml": ubuntu_one_staging_fingerprint,
- }
saml_config = {}
- if self.config.get("saml_target_url"):
- saml_config["DISCOURSE_SAML_TARGET_URL"] = self.config["saml_target_url"]
+ relation = self.model.get_relation(DEFAULT_RELATION_NAME)
+ if (
+ relation is not None
+ and relation.data[self.app]
+ and relation.app
+ and relation.data[relation.app]
+ ):
+ saml_config["DISCOURSE_SAML_TARGET_URL"] = relation.data[relation.app][
+ "single_sign_on_service_redirect_url"
+ ]
saml_config["DISCOURSE_SAML_FULL_SCREEN_LOGIN"] = (
"true" if self.config["force_saml_login"] else "false"
)
- fingerprint = saml_fingerprints.get(self.config["saml_target_url"])
+ fingerprint = relation.data[self.app].get("fingerprint")
if fingerprint:
saml_config["DISCOURSE_SAML_CERT_FINGERPRINT"] = fingerprint
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index f8dcfb5b..9c477190 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -149,8 +149,7 @@ async def app_fixture(
redis_app = await model.deploy("redis-k8s", series="jammy", channel="latest/edge")
await model.wait_for_idle(apps=[redis_app.name], status="active")
- nii_app = await model.deploy("nginx-ingress-integrator", series="focal", trust=True)
- await model.wait_for_idle(apps=[nii_app.name], status="waiting")
+ await model.deploy("nginx-ingress-integrator", series="focal", trust=True)
resources = {
"discourse-image": pytestconfig.getoption("--discourse-image"),
@@ -204,13 +203,10 @@ async def setup_saml_config(app: Application, model: Model):
discourse_app = model.applications[app.name]
original_config: dict = await discourse_app.get_config()
original_config = {k: v["value"] for k, v in original_config.items()}
- await discourse_app.set_config(
- {"saml_target_url": "https://login.staging.ubuntu.com/+saml", "force_https": "true"}
- )
+ await discourse_app.set_config({"force_https": "true"})
yield
await discourse_app.set_config(
{
- "saml_target_url": original_config["saml_target_url"],
"force_https": str(original_config["force_https"]).lower(),
}
)
diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py
index 4e2ad65e..b242deab 100644
--- a/tests/integration/test_charm.py
+++ b/tests/integration/test_charm.py
@@ -16,6 +16,7 @@
from botocore.config import Config
from ops.model import ActiveStatus, Application
from pytest_operator.plugin import Model
+from saml_test_helper import SamlK8sTestHelper # pylint: disable=import-error
from charm import PROMETHEUS_PORT
@@ -179,32 +180,80 @@ def generate_s3_config(localstack_address: str) -> Dict:
}
+@pytest.mark.asyncio
+async def test_create_category(
+ discourse_address: str,
+ admin_credentials: types.Credentials,
+ admin_api_key: str,
+):
+ """
+ arrange: Given discourse application and an admin user
+ act: if an admin user creates a category
+ assert: a category should be created normally.
+ """
+ category_info = {"name": "test", "color": "FFFFFF"}
+ res = requests.post(
+ f"{discourse_address}/categories.json",
+ headers={
+ "Api-Key": admin_api_key,
+ "Api-Username": admin_credentials.username,
+ },
+ json=category_info,
+ timeout=60,
+ )
+ category_id = res.json()["category"]["id"]
+ category = requests.get(f"{discourse_address}/c/{category_id}/show.json", timeout=60).json()[
+ "category"
+ ]
+
+ assert category["name"] == category_info["name"]
+ assert category["color"] == category_info["color"]
+
+
@pytest.mark.asyncio
@pytest.mark.abort_on_fail
-@pytest.mark.requires_secrets
@pytest.mark.usefixtures("setup_saml_config")
async def test_saml_login( # pylint: disable=too-many-locals,too-many-arguments
app: Application,
requests_timeout: int,
run_action,
model: Model,
- saml_email: str,
- saml_password: str,
):
"""
arrange: after discourse charm has been deployed, with all required relation established.
act: add an admin user and enable force-https mode.
assert: user can login discourse using SAML Authentication.
"""
- urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
- action_result = await run_action(
- app.name, "add-admin-user", email=saml_email, password=saml_password
+ saml_helper = SamlK8sTestHelper.deploy_saml_idp(model.name)
+ saml_app: Application = await model.deploy(
+ "saml-integrator",
+ channel="latest/edge",
+ series="jammy",
+ trust=True,
)
+ await model.wait_for_idle()
+ saml_helper.prepare_pod(model.name, f"{saml_app.name}-0")
+ saml_helper.prepare_pod(model.name, f"{app.name}-0")
+ await model.wait_for_idle()
+ await saml_app.set_config( # type: ignore[attr-defined]
+ {
+ "entity_id": saml_helper.entity_id,
+ "metadata_url": saml_helper.metadata_url,
+ }
+ )
+ await model.add_relation(app.name, "saml-integrator")
+ await model.wait_for_idle()
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+ # discourse need a long password and a valid email
+ # username can't be "discourse" or it will be renamed
+ username = "ubuntu"
+ email = "ubuntu@canonical.com"
+ password = "test-discourse-k8s-password" # nosec
+ saml_helper.register_user(username=username, email=email, password=password)
+
+ action_result = await run_action(app.name, "add-admin-user", email=email, password=password)
assert "user" in action_result
- await model.wait_for_idle(status="active")
-
- username = saml_email.split("@")[0]
host = app.name
original_getaddrinfo = socket.getaddrinfo
@@ -215,6 +264,14 @@ def patched_getaddrinfo(*args):
with unittest.mock.patch.multiple(socket, getaddrinfo=patched_getaddrinfo):
session = requests.session()
+
+ response = session.get(
+ f"https://{host}/auth/saml/metadata",
+ verify=False,
+ timeout=10,
+ )
+ saml_helper.register_service_provider(name=host, metadata=response.text)
+
preference_page = session.get(
f"https://{host}/u/{username}/preferences/account",
verify=False,
@@ -229,51 +286,23 @@ def patched_getaddrinfo(*args):
timeout=requests_timeout,
)
csrf_token = response.json()["csrf"]
- login_page = session.post(
+ redirect_response = session.post(
f"https://{host}/auth/saml",
data={"authenticity_token": csrf_token},
timeout=requests_timeout,
+ allow_redirects=False,
)
- csrf_tokens = re.findall(
- "", login_page.text
- )
- assert len(csrf_tokens), login_page.text
- csrf_token = csrf_tokens[0]
- saml_callback = session.post(
- "https://login.staging.ubuntu.com/+login",
- data={
- "csrfmiddlewaretoken": csrf_token,
- "email": saml_email,
- "user-intentions": "login",
- "password": saml_password,
- "next": "/saml/process",
- "continue": "",
- "openid.usernamesecret": "",
- },
- headers={"Referer": login_page.url},
- timeout=requests_timeout,
- )
- saml_responses = re.findall(
- '', saml_callback.text
- )
- assert len(saml_responses), saml_callback.text
- saml_response = saml_responses[0]
- session.post(
- f"https://{host}/auth/saml/callback",
- data={
- "RelayState": "None",
- "SAMLResponse": saml_response,
- "openid.usernamesecret": "",
- },
- verify=False,
- timeout=requests_timeout,
+ assert redirect_response.status_code == 302
+ redirect_url = redirect_response.headers["Location"]
+ saml_response = saml_helper.redirect_sso_login(
+ redirect_url, username=username, password=password
)
+ assert f"https://{host}" in saml_response.url
session.post(
- f"https://{host}/auth/saml/callback",
- data={"SAMLResponse": saml_response, "SameSite": "1"},
- verify=False,
- timeout=requests_timeout,
+ saml_response.url,
+ data={"SAMLResponse": saml_response.data["SAMLResponse"], "SameSite": "1"},
)
+ session.post(saml_response.url, data=saml_response.data)
preference_page = session.get(
f"https://{host}/u/{username}/preferences/account",
@@ -283,36 +312,6 @@ def patched_getaddrinfo(*args):
assert preference_page.status_code == 200
-@pytest.mark.asyncio
-async def test_create_category(
- discourse_address: str,
- admin_credentials: types.Credentials,
- admin_api_key: str,
-):
- """
- arrange: Given discourse application and an admin user
- act: if an admin user creates a category
- assert: a category should be created normally.
- """
- category_info = {"name": "test", "color": "FFFFFF"}
- res = requests.post(
- f"{discourse_address}/categories.json",
- headers={
- "Api-Key": admin_api_key,
- "Api-Username": admin_credentials.username,
- },
- json=category_info,
- timeout=60,
- )
- category_id = res.json()["category"]["id"]
- category = requests.get(f"{discourse_address}/c/{category_id}/show.json", timeout=60).json()[
- "category"
- ]
-
- assert category["name"] == category_info["name"]
- assert category["color"] == category_info["color"]
-
-
@pytest.mark.asyncio
async def test_serve_compiled_assets(
discourse_address: str,
diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py
index 3ac35c93..6cbc79fb 100644
--- a/tests/unit/helpers.py
+++ b/tests/unit/helpers.py
@@ -17,6 +17,7 @@
def start_harness(
*,
+ saml_fields: tuple = (False, "", ""),
with_postgres: bool = True,
with_redis: bool = True,
with_ingress: bool = False,
@@ -51,6 +52,9 @@ def start_harness(
if with_ingress:
_add_ingress_relation(harness)
+ if saml_fields[0]:
+ _add_saml_relation(harness, saml_fields[1], saml_fields[2])
+
if with_config is not None:
harness.update_config(with_config)
@@ -130,3 +134,33 @@ def _add_ingress_relation(harness):
"""
nginx_route_relation_id = harness.add_relation("nginx-route", "ingress")
harness.add_relation_unit(nginx_route_relation_id, "ingress/0")
+
+
+def _add_saml_relation(harness, saml_target, fingerprint):
+ """Add ingress relation and relation data to the charm.
+
+ Args:
+ - A harness instance
+
+ Returns: the same harness instance with an added relation
+ """
+ harness.set_leader(True)
+ saml_relation_id = harness.add_relation("saml", "saml-integrator")
+ harness.add_relation_unit(saml_relation_id, "saml-integrator/0")
+ harness.disable_hooks()
+ harness.update_relation_data(
+ relation_id=saml_relation_id,
+ app_or_unit="saml-integrator",
+ key_values={
+ "single_sign_on_service_redirect_url": f"{saml_target}/+saml",
+ },
+ )
+ harness.enable_hooks()
+ harness.update_relation_data(
+ relation_id=saml_relation_id,
+ app_or_unit=harness.charm.app.name,
+ key_values={
+ "entity_id": saml_target,
+ "fingerprint": fingerprint,
+ },
+ )
diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py
index b65ee0d8..f3e2c48a 100644
--- a/tests/unit/test_charm.py
+++ b/tests/unit/test_charm.py
@@ -71,9 +71,9 @@ def test_on_config_changed_when_no_saml_target():
act: when force_saml_login configuration is True and there's no saml_target_url
assert: it will get to blocked status waiting for the latter.
"""
- harness = helpers.start_harness(with_config={"force_saml_login": True, "saml_target_url": ""})
+ harness = helpers.start_harness(with_config={"force_saml_login": True})
assert harness.model.unit.status == BlockedStatus(
- "force_saml_login can not be true without a saml_target_url"
+ "force_saml_login cannot be true without a saml relation"
)
@@ -83,11 +83,9 @@ def test_on_config_changed_when_saml_sync_groups_and_no_url_invalid():
act: when saml_sync_groups configuration is provided and there's no saml_target_url
assert: it will get to blocked status waiting for the latter.
"""
- harness = helpers.start_harness(
- with_config={"saml_sync_groups": "group1", "saml_target_url": ""}
- )
+ harness = helpers.start_harness(with_config={"saml_sync_groups": "group1"})
assert harness.model.unit.status == BlockedStatus(
- "'saml_sync_groups' cannot be specified without a 'saml_target_url'"
+ "'saml_sync_groups' cannot be specified without a saml relation"
)
@@ -97,11 +95,10 @@ def test_on_config_changed_when_saml_target_url_and_force_https_disabled():
act: when saml_target_url configuration is provided and force_https is False
assert: it will get to blocked status waiting for the latter.
"""
- harness = helpers.start_harness(
- with_config={"saml_target_url": "group1", "force_https": False}
- )
+ harness = helpers.start_harness(with_config={"force_https": False}, saml_fields=(True, "", ""))
+ harness.charm._is_config_valid()
assert harness.model.unit.status == BlockedStatus(
- "'saml_target_url' cannot be specified without 'force_https' being true"
+ "A saml relation cannot be specified without 'force_https' being true"
)
@@ -229,13 +226,14 @@ def test_on_config_changed_when_valid_no_fingerprint():
harness = helpers.start_harness(
with_config={
"force_saml_login": True,
- "saml_target_url": "https://login.sample.com/+saml",
"saml_sync_groups": "group1",
"s3_enabled": False,
"force_https": True,
- }
+ },
+ saml_fields=(True, "https://login.sample.com", ""),
)
- harness.container_pebble_ready("discourse")
+
+ harness.container_pebble_ready(SERVICE_NAME)
updated_plan = harness.get_container_pebble_plan(SERVICE_NAME).to_dict()
updated_plan_env = updated_plan["services"][SERVICE_NAME]["environment"]
@@ -274,7 +272,6 @@ def test_on_config_changed_when_valid():
"enable_cors": True,
"external_hostname": "discourse.local",
"force_saml_login": True,
- "saml_target_url": "https://login.ubuntu.com/+saml",
"saml_sync_groups": "group1",
"smtp_address": "smtp.internal",
"smtp_domain": "foo.internal",
@@ -289,9 +286,10 @@ def test_on_config_changed_when_valid():
"s3_region": "the-infinite-and-beyond",
"s3_secret_access_key": "s|kI0ure_k3Y",
"force_https": True,
- }
+ },
+ saml_fields=(True, "https://login.ubuntu.com", "fingerprint"),
)
- harness.container_pebble_ready("discourse")
+ harness.container_pebble_ready(SERVICE_NAME)
updated_plan = harness.get_container_pebble_plan(SERVICE_NAME).to_dict()
updated_plan_env = updated_plan["services"][SERVICE_NAME]["environment"]
diff --git a/tox.ini b/tox.ini
index 81e0c5c2..8509f327 100644
--- a/tox.ini
+++ b/tox.ini
@@ -79,7 +79,7 @@ commands =
[testenv:unit]
description = Run unit tests
deps =
- cosl
+ cosl==0.0.8
pytest
coverage[toml]
ops>=2.6.0
@@ -114,10 +114,12 @@ deps =
cosl
juju>=3.0
protobuf==3.20.3
+ pydantic==1.10.14
pytest
pytest-operator
pytest-asyncio
psycopg2-binary
+ git+https://github.com/canonical/saml-test-idp.git
-r{toxinidir}/requirements.txt
commands =
pytest -v --tb native --ignore={[vars]tst_path}unit --log-cli-level=INFO -s {posargs}