Skip to content

Commit f41b5bf

Browse files
authored
Merge pull request #370 from NHSDigital/10291-logout-cis2--II
10291 CIS2 refactors and improvements
2 parents 7f72ae7 + b6518b5 commit f41b5bf

File tree

12 files changed

+253
-203
lines changed

12 files changed

+253
-203
lines changed

Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ ARG poetry_version
2020

2121
WORKDIR /app
2222

23+
RUN apk add --no-cache libgcc libstdc++ build-base
24+
2325
# Set environment variables
2426
ENV PYTHONDONTWRITEBYTECODE=1
2527
ENV PYTHONUNBUFFERED=1

docs/cis2_auth.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ The following settings in `manage_breast_screening/config/settings.py` provide t
4040

4141
- **`CIS2_SERVER_METADATA_URL`** - URL to CIS2's OpenID Connect discovery document
4242
- **`CIS2_CLIENT_ID`** - Client identifier registered with CIS2
43-
- **`CIS2_PRIVATE_KEY`** - RSA private key in PEM format for JWT signing (supports `\n` escaped newlines)
44-
- **`CIS2_PUBLIC_KEY`** - Corresponding RSA public key in PEM format (supports `\n` escaped newlines)
43+
- **`CIS2_CLIENT_PRIVATE_KEY`** - RSA private key in PEM format for JWT signing (supports `\n` escaped newlines)
44+
- **`CIS2_CLIENT_PUBLIC_KEY`** - Corresponding RSA public key in PEM format (supports `\n` escaped newlines)
4545
- **`CIS2_SCOPES`** - OAuth scopes requested
4646

4747
The private and public keys form a keypair used for the `private_key_jwt` client authentication method.

manage_breast_screening/auth/oauth.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,27 @@ def get_cis2_client():
1515
name
1616
for name in [
1717
"CIS2_CLIENT_ID",
18-
"CIS2_PRIVATE_KEY",
19-
"CIS2_PUBLIC_KEY",
18+
"CIS2_CLIENT_PRIVATE_KEY",
19+
"CIS2_CLIENT_PUBLIC_KEY",
2020
"CIS2_SERVER_METADATA_URL",
2121
]
2222
if not getattr(settings, name, None)
2323
]
2424
if missing:
2525
raise ValueError(f"Missing required CIS2 OAuth settings: {', '.join(missing)}")
2626

27+
# Return existing client if already registered
28+
client = oauth._clients.get("cis2")
29+
if client:
30+
return client
31+
2732
jwk = jwk_from_public_key()
2833
kid = jwk.thumbprint()
2934

3035
client = oauth.register(
3136
"cis2",
3237
client_id=settings.CIS2_CLIENT_ID,
33-
client_secret=settings.CIS2_PRIVATE_KEY,
38+
client_secret=settings.CIS2_CLIENT_PRIVATE_KEY,
3439
server_metadata_url=settings.CIS2_SERVER_METADATA_URL,
3540
client_kwargs={
3641
"scope": settings.CIS2_SCOPES,
@@ -68,7 +73,7 @@ def jwk_from_public_key():
6873
Returns:
6974
JsonWebKey | None: Public JWK or None on failure.
7075
"""
71-
jwk = JsonWebKey.import_key(settings.CIS2_PUBLIC_KEY, {"kty": "RSA"})
76+
jwk = JsonWebKey.import_key(settings.CIS2_CLIENT_PUBLIC_KEY, {"kty": "RSA"})
7277
return jwk
7378

7479

manage_breast_screening/auth/services.py

Lines changed: 41 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,64 +8,52 @@
88
InvalidTokenError,
99
MissingClaimError,
1010
)
11+
from django.conf import settings
1112

1213
logger = logging.getLogger(__name__)
1314

1415

1516
class InvalidLogoutToken(Exception):
1617
"""Raised when a CIS2 back-channel logout token is invalid."""
1718

18-
def __init__(self, cause=None):
19-
"""
20-
Initialize with optional cause.
2119

22-
Args:
23-
cause: The original exception that caused this error
24-
"""
25-
self.cause = cause
26-
super().__init__(str(cause) if cause else "Invalid logout token")
27-
28-
29-
class DecodeLogoutToken:
30-
def call(
31-
self,
32-
metadata: Dict[str, Any],
33-
logout_token: str,
34-
client_id: str,
35-
key_loader,
36-
):
37-
"""
38-
Decode and validate a CIS2 back-channel logout token.
39-
40-
Returns the decoded claims on success and raises InvalidLogoutToken on
41-
any underlying decoding/validation error.
42-
"""
43-
try:
44-
verification_rules = {
45-
"iss": {"values": [metadata["issuer"]], "essential": True},
46-
"aud": {"values": [client_id], "essential": True},
47-
"exp": {"essential": True},
48-
"iat": {"essential": True},
49-
"events": {
50-
"essential": True,
51-
"validate": lambda claim, value: isinstance(value, dict)
52-
and "http://schemas.openid.net/event/backchannel-logout" in value,
53-
},
54-
"nonce": {"validate": lambda claim, value: value is None},
55-
}
56-
claims = jwt.decode(
57-
logout_token,
58-
key=key_loader,
59-
claims_options=verification_rules,
60-
)
61-
claims.validate(leeway=60)
62-
return claims
63-
except (
64-
ExpiredTokenError,
65-
InvalidClaimError,
66-
InvalidTokenError,
67-
JoseError,
68-
MissingClaimError,
69-
) as e:
70-
logger.exception("Invalid logout token")
71-
raise InvalidLogoutToken(cause=e) from e
20+
def decode_logout_token(
21+
issuer: str,
22+
key_loader,
23+
logout_token: str,
24+
) -> Dict[str, Any]:
25+
"""
26+
Decode and validate a CIS2 back-channel logout token.
27+
28+
Returns the decoded claims on success and raises InvalidLogoutToken on
29+
any underlying decoding/validation error.
30+
"""
31+
try:
32+
verification_rules = {
33+
"iss": {"values": [issuer], "essential": True},
34+
"aud": {"values": [settings.CIS2_CLIENT_ID], "essential": True},
35+
"exp": {"essential": True},
36+
"iat": {"essential": True},
37+
"events": {
38+
"essential": True,
39+
"validate": lambda claim, value: isinstance(value, dict)
40+
and "http://schemas.openid.net/event/backchannel-logout" in value,
41+
},
42+
"nonce": {"validate": lambda claim, value: value is None},
43+
}
44+
claims = jwt.decode(
45+
logout_token,
46+
key=key_loader,
47+
claims_options=verification_rules,
48+
)
49+
claims.validate(leeway=60)
50+
return claims
51+
except (
52+
ExpiredTokenError,
53+
InvalidClaimError,
54+
InvalidTokenError,
55+
JoseError,
56+
MissingClaimError,
57+
) as e:
58+
logger.exception("Invalid logout token")
59+
raise InvalidLogoutToken() from e

manage_breast_screening/auth/tests/test_services.py

Lines changed: 40 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
JoseError,
99
MissingClaimError,
1010
)
11+
from django.conf import settings
1112

1213
from manage_breast_screening.auth.services import (
13-
DecodeLogoutToken,
1414
InvalidLogoutToken,
15+
decode_logout_token,
1516
)
1617

1718

@@ -30,13 +31,12 @@ def _make_token(
3031
private_jwk: dict,
3132
kid: str,
3233
issuer: str,
33-
client_id: str,
3434
overrides: dict | None = None,
3535
) -> str:
3636
now = int(time.time())
3737
payload = {
3838
"iss": issuer,
39-
"aud": client_id,
39+
"aud": settings.CIS2_CLIENT_ID,
4040
"iat": now,
4141
"exp": now + 300,
4242
"events": {"http://schemas.openid.net/event/backchannel-logout": {}},
@@ -57,28 +57,24 @@ def _make_token(
5757

5858
@staticmethod
5959
def _key_loader(public_jwk: dict):
60-
def loader(headers, payload):
60+
def loader(_headers, _payload):
6161
return public_jwk
6262

6363
return loader
6464

6565
def test_valid_token_returns_claims(self):
6666
kid = "k1"
67-
issuer = "test-issuer"
68-
client_id = "client-1"
6967
private_jwk, public_jwk = self._make_keys(kid)
70-
token = self._make_token(private_jwk, kid, issuer, client_id)
71-
72-
service = DecodeLogoutToken()
73-
claims = service.call(
74-
metadata={"issuer": issuer},
75-
logout_token=token,
76-
client_id=client_id,
77-
key_loader=self._key_loader(public_jwk),
68+
token = self._make_token(private_jwk, kid, "test-issuer")
69+
70+
claims = decode_logout_token(
71+
"test-issuer",
72+
self._key_loader(public_jwk),
73+
token,
7874
)
7975

80-
assert claims["iss"] == issuer
81-
assert claims["aud"] == client_id
76+
assert claims["iss"] == "test-issuer"
77+
assert claims["aud"] == settings.CIS2_CLIENT_ID
8278
assert claims["sub"] == "user-123"
8379
assert "http://schemas.openid.net/event/backchannel-logout" in claims["events"]
8480

@@ -108,95 +104,71 @@ def test_invalid_claims_raise_error(
108104
self, overrides, expected_error_type, expected_error_text
109105
):
110106
kid = "k1"
111-
issuer = "test-issuer"
112-
client_id = "client-1"
113107
private_jwk, public_jwk = self._make_keys(kid)
114-
token = self._make_token(
115-
private_jwk, kid, issuer, client_id, overrides=overrides
116-
)
108+
token = self._make_token(private_jwk, kid, "test-issuer", overrides=overrides)
117109

118-
service = DecodeLogoutToken()
119110
with pytest.raises(InvalidLogoutToken) as excinfo:
120-
service.call(
121-
metadata={"issuer": issuer},
122-
logout_token=token,
123-
client_id=client_id,
124-
key_loader=self._key_loader(public_jwk),
111+
decode_logout_token(
112+
"test-issuer",
113+
self._key_loader(public_jwk),
114+
token,
125115
)
126116

127-
# Assert on the cause type and error message content
128-
assert isinstance(excinfo.value.cause, expected_error_type)
129-
assert expected_error_text in str(excinfo.value.cause)
117+
assert isinstance(excinfo.value.__cause__, expected_error_type)
118+
assert expected_error_text in str(excinfo.value.__cause__)
130119

131120
def test_invalid_signature_raises_error(self):
132121
kid = "k1"
133-
issuer = "test-issuer"
134-
client_id = "client-1"
135122
# Create two different key pairs
136123
private_jwk_1, public_jwk_1 = self._make_keys(kid)
137124
private_jwk_2, _public_jwk_2 = self._make_keys(kid)
138125
# Sign with private_jwk_2 but verify with public_jwk_1 -> invalid signature
139-
token = self._make_token(private_jwk_2, kid, issuer, client_id)
126+
token = self._make_token(private_jwk_2, kid, "test-issuer")
140127

141-
service = DecodeLogoutToken()
142128
with pytest.raises(InvalidLogoutToken) as excinfo:
143-
service.call(
144-
metadata={"issuer": issuer},
145-
logout_token=token,
146-
client_id=client_id,
147-
key_loader=self._key_loader(public_jwk_1),
129+
decode_logout_token(
130+
"test-issuer",
131+
self._key_loader(public_jwk_1),
132+
token,
148133
)
149134

150-
# Invalid signature should raise a JoseError
151-
assert isinstance(excinfo.value.cause, JoseError)
152-
assert "signature" in str(excinfo.value.cause).lower()
135+
assert isinstance(excinfo.value.__cause__, JoseError)
136+
assert "signature" in str(excinfo.value.__cause__)
153137

154138
def test_expired_token_raises_error(self):
155139
kid = "k1"
156-
issuer = "test-issuer"
157-
client_id = "client-1"
158140
private_jwk, public_jwk = self._make_keys(kid)
159141
now = int(time.time())
160142
token = self._make_token(
161143
private_jwk,
162144
kid,
163-
issuer,
164-
client_id,
145+
"test-issuer",
165146
overrides={"exp": now - 120}, # already expired beyond leeway
166147
)
167148

168-
service = DecodeLogoutToken()
169149
with pytest.raises(InvalidLogoutToken) as excinfo:
170-
service.call(
171-
metadata={"issuer": issuer},
172-
logout_token=token,
173-
client_id=client_id,
174-
key_loader=self._key_loader(public_jwk),
150+
decode_logout_token(
151+
"test-issuer",
152+
self._key_loader(public_jwk),
153+
token,
175154
)
176-
177-
# Expired token should raise an ExpiredTokenError
178-
assert isinstance(excinfo.value.cause, ExpiredTokenError)
179-
assert "expired" in str(excinfo.value.cause).lower()
155+
assert isinstance(excinfo.value.__cause__, ExpiredTokenError)
156+
assert "expired" in str(excinfo.value.__cause__)
180157

181158
def test_missing_iat_raises_error(self):
182159
kid = "k1"
183-
issuer = "test-issuer"
184-
client_id = "client-1"
185160
private_jwk, public_jwk = self._make_keys(kid)
186161
# Use overrides to remove the iat claim
187162
token = self._make_token(
188-
private_jwk, kid, issuer, client_id, overrides={"iat": None}
163+
private_jwk, kid, "test-issuer", overrides={"iat": None}
189164
)
190165

191-
service = DecodeLogoutToken()
192166
with pytest.raises(InvalidLogoutToken) as excinfo:
193-
service.call(
194-
metadata={"issuer": issuer},
195-
logout_token=token,
196-
client_id=client_id,
197-
key_loader=self._key_loader(public_jwk),
167+
decode_logout_token(
168+
"test-issuer",
169+
self._key_loader(public_jwk),
170+
token,
198171
)
199172

200-
# Missing iat should raise a MissingClaimError
201-
assert isinstance(excinfo.value.cause, MissingClaimError)
202-
assert "iat" in str(excinfo.value.cause).lower()
173+
assert isinstance(excinfo.value.__cause__, MissingClaimError)
174+
assert "iat" in str(excinfo.value.__cause__)

0 commit comments

Comments
 (0)