Skip to content

Commit 530cf91

Browse files
authored
Merge pull request #355 from NHSDigital/10291-logout-cis2
Add CIS2 back channel logout endpoint
2 parents 3de1cfd + 15e28f5 commit 530cf91

File tree

9 files changed

+562
-11
lines changed

9 files changed

+562
-11
lines changed

.gitleaksignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ infrastructure/terraform/resource_group_init/main.bicep:generic-api-key:32
1717
infrastructure/terraform/resource_group_init/main.bicep:generic-api-key:33
1818
infrastructure/terraform/resource_group_init/storage.bicep:generic-api-key:59
1919
infrastructure/terraform/resource_group_init/keyVault.bicep:generic-api-key:10
20+
manage_breast_screening/config/settings_test.py:private-key:34
21+
manage_breast_screening/config/settings_test.py:public-key:61
File renamed without changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import logging
2+
from typing import Any, Dict
3+
4+
from authlib.jose import JoseError, jwt
5+
from authlib.jose.errors import (
6+
ExpiredTokenError,
7+
InvalidClaimError,
8+
InvalidTokenError,
9+
MissingClaimError,
10+
)
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class InvalidLogoutToken(Exception):
16+
"""Raised when a CIS2 back-channel logout token is invalid."""
17+
18+
def __init__(self, cause=None):
19+
"""
20+
Initialize with optional cause.
21+
22+
Args:
23+
cause: The original exception that caused this error
24+
"""
25+
self.cause = cause
26+
super().__init__(str(cause) if cause else "Invalid logout token")
27+
28+
29+
class DecodeLogoutToken:
30+
def call(
31+
self,
32+
metadata: Dict[str, Any],
33+
logout_token: str,
34+
client_id: str,
35+
key_loader,
36+
):
37+
"""
38+
Decode and validate a CIS2 back-channel logout token.
39+
40+
Returns the decoded claims on success and raises InvalidLogoutToken on
41+
any underlying decoding/validation error.
42+
"""
43+
try:
44+
verification_rules = {
45+
"iss": {"values": [metadata["issuer"]], "essential": True},
46+
"aud": {"values": [client_id], "essential": True},
47+
"exp": {"essential": True},
48+
"iat": {"essential": True},
49+
"events": {
50+
"essential": True,
51+
"validate": lambda claim, value: isinstance(value, dict)
52+
and "http://schemas.openid.net/event/backchannel-logout" in value,
53+
},
54+
"nonce": {"validate": lambda claim, value: value is None},
55+
}
56+
claims = jwt.decode(
57+
logout_token,
58+
key=key_loader,
59+
claims_options=verification_rules,
60+
)
61+
claims.validate(leeway=60)
62+
return claims
63+
except (
64+
ExpiredTokenError,
65+
InvalidClaimError,
66+
InvalidTokenError,
67+
JoseError,
68+
MissingClaimError,
69+
) as e:
70+
logger.exception("Invalid logout token")
71+
raise InvalidLogoutToken(cause=e) from e
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import time
2+
3+
import pytest
4+
from authlib.jose import JsonWebKey, jwt
5+
from authlib.jose.errors import (
6+
ExpiredTokenError,
7+
InvalidClaimError,
8+
JoseError,
9+
MissingClaimError,
10+
)
11+
12+
from manage_breast_screening.auth.services import (
13+
DecodeLogoutToken,
14+
InvalidLogoutToken,
15+
)
16+
17+
18+
class TestDecodeLogoutToken:
19+
@staticmethod
20+
def _make_keys(kid: str):
21+
private_jwk = JsonWebKey.generate_key(
22+
"RSA", 2048, is_private=True, options={"kid": kid}
23+
)
24+
public_jwk = private_jwk.as_dict(is_private=False)
25+
private_jwk_dict = private_jwk.as_dict(is_private=True)
26+
return private_jwk_dict, public_jwk
27+
28+
@staticmethod
29+
def _make_token(
30+
private_jwk: dict,
31+
kid: str,
32+
issuer: str,
33+
client_id: str,
34+
overrides: dict | None = None,
35+
) -> str:
36+
now = int(time.time())
37+
payload = {
38+
"iss": issuer,
39+
"aud": client_id,
40+
"iat": now,
41+
"exp": now + 300,
42+
"events": {"http://schemas.openid.net/event/backchannel-logout": {}},
43+
"sub": "user-123",
44+
}
45+
if overrides:
46+
# Apply overrides, removing keys with None values
47+
for key, value in overrides.items():
48+
if value is None:
49+
if key in payload:
50+
del payload[key]
51+
else:
52+
payload[key] = value
53+
54+
headers = {"alg": "RS256", "kid": kid}
55+
token = jwt.encode(headers, payload, private_jwk)
56+
return token.decode("utf-8")
57+
58+
@staticmethod
59+
def _key_loader(public_jwk: dict):
60+
def loader(headers, payload):
61+
return public_jwk
62+
63+
return loader
64+
65+
def test_valid_token_returns_claims(self):
66+
kid = "k1"
67+
issuer = "test-issuer"
68+
client_id = "client-1"
69+
private_jwk, public_jwk = self._make_keys(kid)
70+
token = self._make_token(private_jwk, kid, issuer, client_id)
71+
72+
service = DecodeLogoutToken()
73+
claims = service.call(
74+
metadata={"issuer": issuer},
75+
logout_token=token,
76+
client_id=client_id,
77+
key_loader=self._key_loader(public_jwk),
78+
)
79+
80+
assert claims["iss"] == issuer
81+
assert claims["aud"] == client_id
82+
assert claims["sub"] == "user-123"
83+
assert "http://schemas.openid.net/event/backchannel-logout" in claims["events"]
84+
85+
@pytest.mark.parametrize(
86+
"overrides,expected_error_type,expected_error_text",
87+
[
88+
({"iss": "wrong-issuer"}, InvalidClaimError, "iss"), # invalid issuer
89+
({"aud": "wrong-aud"}, InvalidClaimError, "aud"), # invalid audience
90+
(
91+
{"events": {}},
92+
InvalidClaimError,
93+
"events",
94+
), # missing backchannel-logout event
95+
(
96+
{"events": {"some-other": {}}},
97+
InvalidClaimError,
98+
"events",
99+
), # wrong events
100+
(
101+
{"nonce": "should-be-none"},
102+
InvalidClaimError,
103+
"nonce",
104+
), # nonce must be None
105+
],
106+
)
107+
def test_invalid_claims_raise_error(
108+
self, overrides, expected_error_type, expected_error_text
109+
):
110+
kid = "k1"
111+
issuer = "test-issuer"
112+
client_id = "client-1"
113+
private_jwk, public_jwk = self._make_keys(kid)
114+
token = self._make_token(
115+
private_jwk, kid, issuer, client_id, overrides=overrides
116+
)
117+
118+
service = DecodeLogoutToken()
119+
with pytest.raises(InvalidLogoutToken) as excinfo:
120+
service.call(
121+
metadata={"issuer": issuer},
122+
logout_token=token,
123+
client_id=client_id,
124+
key_loader=self._key_loader(public_jwk),
125+
)
126+
127+
# Assert on the cause type and error message content
128+
assert isinstance(excinfo.value.cause, expected_error_type)
129+
assert expected_error_text in str(excinfo.value.cause)
130+
131+
def test_invalid_signature_raises_error(self):
132+
kid = "k1"
133+
issuer = "test-issuer"
134+
client_id = "client-1"
135+
# Create two different key pairs
136+
private_jwk_1, public_jwk_1 = self._make_keys(kid)
137+
private_jwk_2, _public_jwk_2 = self._make_keys(kid)
138+
# Sign with private_jwk_2 but verify with public_jwk_1 -> invalid signature
139+
token = self._make_token(private_jwk_2, kid, issuer, client_id)
140+
141+
service = DecodeLogoutToken()
142+
with pytest.raises(InvalidLogoutToken) as excinfo:
143+
service.call(
144+
metadata={"issuer": issuer},
145+
logout_token=token,
146+
client_id=client_id,
147+
key_loader=self._key_loader(public_jwk_1),
148+
)
149+
150+
# Invalid signature should raise a JoseError
151+
assert isinstance(excinfo.value.cause, JoseError)
152+
assert "signature" in str(excinfo.value.cause).lower()
153+
154+
def test_expired_token_raises_error(self):
155+
kid = "k1"
156+
issuer = "test-issuer"
157+
client_id = "client-1"
158+
private_jwk, public_jwk = self._make_keys(kid)
159+
now = int(time.time())
160+
token = self._make_token(
161+
private_jwk,
162+
kid,
163+
issuer,
164+
client_id,
165+
overrides={"exp": now - 120}, # already expired beyond leeway
166+
)
167+
168+
service = DecodeLogoutToken()
169+
with pytest.raises(InvalidLogoutToken) as excinfo:
170+
service.call(
171+
metadata={"issuer": issuer},
172+
logout_token=token,
173+
client_id=client_id,
174+
key_loader=self._key_loader(public_jwk),
175+
)
176+
177+
# Expired token should raise an ExpiredTokenError
178+
assert isinstance(excinfo.value.cause, ExpiredTokenError)
179+
assert "expired" in str(excinfo.value.cause).lower()
180+
181+
def test_missing_iat_raises_error(self):
182+
kid = "k1"
183+
issuer = "test-issuer"
184+
client_id = "client-1"
185+
private_jwk, public_jwk = self._make_keys(kid)
186+
# Use overrides to remove the iat claim
187+
token = self._make_token(
188+
private_jwk, kid, issuer, client_id, overrides={"iat": None}
189+
)
190+
191+
service = DecodeLogoutToken()
192+
with pytest.raises(InvalidLogoutToken) as excinfo:
193+
service.call(
194+
metadata={"issuer": issuer},
195+
logout_token=token,
196+
client_id=client_id,
197+
key_loader=self._key_loader(public_jwk),
198+
)
199+
200+
# Missing iat should raise a MissingClaimError
201+
assert isinstance(excinfo.value.cause, MissingClaimError)
202+
assert "iat" in str(excinfo.value.cause).lower()

manage_breast_screening/auth/urls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
# CIS2 OpenID Connect
1212
path("cis2/log-in/", views.cis2_login, name="cis2_login"),
1313
path("cis2/callback/", views.cis2_callback, name="cis2_callback"),
14+
path(
15+
"cis2/back-channel-logout/",
16+
views.cis2_back_channel_logout,
17+
name="cis2_back_channel_logout",
18+
),
1419
# JWKS endpoint for private_key_jwt
1520
path("cis2/jwks_uri", views.jwks, name="jwks"),
1621
]

0 commit comments

Comments
 (0)