Skip to content

Commit ed7169f

Browse files
committed
Simplify decode logout token service
- Implement as a simple function rather than a class - Assume the client id is always the one configured in CIS2_CLIENT_ID, rather than a value injected into the service - Accept the issuer value directly, rather than the entire metadata dict
1 parent a6efd73 commit ed7169f

File tree

3 files changed

+74
-104
lines changed

3 files changed

+74
-104
lines changed

manage_breast_screening/auth/services.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
InvalidTokenError,
99
MissingClaimError,
1010
)
11+
from django.conf import settings
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -16,46 +17,43 @@ class InvalidLogoutToken(Exception):
1617
"""Raised when a CIS2 back-channel logout token is invalid."""
1718

1819

19-
class DecodeLogoutToken:
20-
def call(
21-
self,
22-
metadata: Dict[str, Any],
23-
logout_token: str,
24-
client_id: str,
25-
key_loader,
26-
):
27-
"""
28-
Decode and validate a CIS2 back-channel logout token.
20+
def decode_logout_token(
21+
issuer: str,
22+
key_loader,
23+
logout_token: str,
24+
) -> Dict[str, Any]:
25+
"""
26+
Decode and validate a CIS2 back-channel logout token.
2927
30-
Returns the decoded claims on success and raises InvalidLogoutToken on
31-
any underlying decoding/validation error.
32-
"""
33-
try:
34-
verification_rules = {
35-
"iss": {"values": [metadata["issuer"]], "essential": True},
36-
"aud": {"values": [client_id], "essential": True},
37-
"exp": {"essential": True},
38-
"iat": {"essential": True},
39-
"events": {
40-
"essential": True,
41-
"validate": lambda claim, value: isinstance(value, dict)
42-
and "http://schemas.openid.net/event/backchannel-logout" in value,
43-
},
44-
"nonce": {"validate": lambda claim, value: value is None},
45-
}
46-
claims = jwt.decode(
47-
logout_token,
48-
key=key_loader,
49-
claims_options=verification_rules,
50-
)
51-
claims.validate(leeway=60)
52-
return claims
53-
except (
54-
ExpiredTokenError,
55-
InvalidClaimError,
56-
InvalidTokenError,
57-
JoseError,
58-
MissingClaimError,
59-
) as e:
60-
logger.exception("Invalid logout token")
61-
raise InvalidLogoutToken() from e
28+
Returns the decoded claims on success and raises InvalidLogoutToken on
29+
any underlying decoding/validation error.
30+
"""
31+
try:
32+
verification_rules = {
33+
"iss": {"values": [issuer], "essential": True},
34+
"aud": {"values": [settings.CIS2_CLIENT_ID], "essential": True},
35+
"exp": {"essential": True},
36+
"iat": {"essential": True},
37+
"events": {
38+
"essential": True,
39+
"validate": lambda claim, value: isinstance(value, dict)
40+
and "http://schemas.openid.net/event/backchannel-logout" in value,
41+
},
42+
"nonce": {"validate": lambda claim, value: value is None},
43+
}
44+
claims = jwt.decode(
45+
logout_token,
46+
key=key_loader,
47+
claims_options=verification_rules,
48+
)
49+
claims.validate(leeway=60)
50+
return claims
51+
except (
52+
ExpiredTokenError,
53+
InvalidClaimError,
54+
InvalidTokenError,
55+
JoseError,
56+
MissingClaimError,
57+
) as e:
58+
logger.exception("Invalid logout token")
59+
raise InvalidLogoutToken() from e

manage_breast_screening/auth/tests/test_services.py

Lines changed: 32 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
JoseError,
99
MissingClaimError,
1010
)
11+
from django.conf import settings
1112

1213
from manage_breast_screening.auth.services import (
13-
DecodeLogoutToken,
1414
InvalidLogoutToken,
15+
decode_logout_token,
1516
)
1617

1718

@@ -30,13 +31,12 @@ def _make_token(
3031
private_jwk: dict,
3132
kid: str,
3233
issuer: str,
33-
client_id: str,
3434
overrides: dict | None = None,
3535
) -> str:
3636
now = int(time.time())
3737
payload = {
3838
"iss": issuer,
39-
"aud": client_id,
39+
"aud": settings.CIS2_CLIENT_ID,
4040
"iat": now,
4141
"exp": now + 300,
4242
"events": {"http://schemas.openid.net/event/backchannel-logout": {}},
@@ -57,28 +57,24 @@ def _make_token(
5757

5858
@staticmethod
5959
def _key_loader(public_jwk: dict):
60-
def loader(headers, payload):
60+
def loader(_headers, _payload):
6161
return public_jwk
6262

6363
return loader
6464

6565
def test_valid_token_returns_claims(self):
6666
kid = "k1"
67-
issuer = "test-issuer"
68-
client_id = "client-1"
6967
private_jwk, public_jwk = self._make_keys(kid)
70-
token = self._make_token(private_jwk, kid, issuer, client_id)
71-
72-
service = DecodeLogoutToken()
73-
claims = service.call(
74-
metadata={"issuer": issuer},
75-
logout_token=token,
76-
client_id=client_id,
77-
key_loader=self._key_loader(public_jwk),
68+
token = self._make_token(private_jwk, kid, "test-issuer")
69+
70+
claims = decode_logout_token(
71+
"test-issuer",
72+
self._key_loader(public_jwk),
73+
token,
7874
)
7975

80-
assert claims["iss"] == issuer
81-
assert claims["aud"] == client_id
76+
assert claims["iss"] == "test-issuer"
77+
assert claims["aud"] == settings.CIS2_CLIENT_ID
8278
assert claims["sub"] == "user-123"
8379
assert "http://schemas.openid.net/event/backchannel-logout" in claims["events"]
8480

@@ -108,89 +104,70 @@ def test_invalid_claims_raise_error(
108104
self, overrides, expected_error_type, expected_error_text
109105
):
110106
kid = "k1"
111-
issuer = "test-issuer"
112-
client_id = "client-1"
113107
private_jwk, public_jwk = self._make_keys(kid)
114-
token = self._make_token(
115-
private_jwk, kid, issuer, client_id, overrides=overrides
116-
)
108+
token = self._make_token(private_jwk, kid, "test-issuer", overrides=overrides)
117109

118-
service = DecodeLogoutToken()
119110
with pytest.raises(InvalidLogoutToken) as excinfo:
120-
service.call(
121-
metadata={"issuer": issuer},
122-
logout_token=token,
123-
client_id=client_id,
124-
key_loader=self._key_loader(public_jwk),
111+
decode_logout_token(
112+
"test-issuer",
113+
self._key_loader(public_jwk),
114+
token,
125115
)
126116

127117
assert isinstance(excinfo.value.__cause__, expected_error_type)
128118
assert expected_error_text in str(excinfo.value.__cause__)
129119

130120
def test_invalid_signature_raises_error(self):
131121
kid = "k1"
132-
issuer = "test-issuer"
133-
client_id = "client-1"
134122
# Create two different key pairs
135123
private_jwk_1, public_jwk_1 = self._make_keys(kid)
136124
private_jwk_2, _public_jwk_2 = self._make_keys(kid)
137125
# Sign with private_jwk_2 but verify with public_jwk_1 -> invalid signature
138-
token = self._make_token(private_jwk_2, kid, issuer, client_id)
126+
token = self._make_token(private_jwk_2, kid, "test-issuer")
139127

140-
service = DecodeLogoutToken()
141128
with pytest.raises(InvalidLogoutToken) as excinfo:
142-
service.call(
143-
metadata={"issuer": issuer},
144-
logout_token=token,
145-
client_id=client_id,
146-
key_loader=self._key_loader(public_jwk_1),
129+
decode_logout_token(
130+
"test-issuer",
131+
self._key_loader(public_jwk_1),
132+
token,
147133
)
148134

149135
assert isinstance(excinfo.value.__cause__, JoseError)
150136
assert "signature" in str(excinfo.value.__cause__)
151137

152138
def test_expired_token_raises_error(self):
153139
kid = "k1"
154-
issuer = "test-issuer"
155-
client_id = "client-1"
156140
private_jwk, public_jwk = self._make_keys(kid)
157141
now = int(time.time())
158142
token = self._make_token(
159143
private_jwk,
160144
kid,
161-
issuer,
162-
client_id,
145+
"test-issuer",
163146
overrides={"exp": now - 120}, # already expired beyond leeway
164147
)
165148

166-
service = DecodeLogoutToken()
167149
with pytest.raises(InvalidLogoutToken) as excinfo:
168-
service.call(
169-
metadata={"issuer": issuer},
170-
logout_token=token,
171-
client_id=client_id,
172-
key_loader=self._key_loader(public_jwk),
150+
decode_logout_token(
151+
"test-issuer",
152+
self._key_loader(public_jwk),
153+
token,
173154
)
174155
assert isinstance(excinfo.value.__cause__, ExpiredTokenError)
175156
assert "expired" in str(excinfo.value.__cause__)
176157

177158
def test_missing_iat_raises_error(self):
178159
kid = "k1"
179-
issuer = "test-issuer"
180-
client_id = "client-1"
181160
private_jwk, public_jwk = self._make_keys(kid)
182161
# Use overrides to remove the iat claim
183162
token = self._make_token(
184-
private_jwk, kid, issuer, client_id, overrides={"iat": None}
163+
private_jwk, kid, "test-issuer", overrides={"iat": None}
185164
)
186165

187-
service = DecodeLogoutToken()
188166
with pytest.raises(InvalidLogoutToken) as excinfo:
189-
service.call(
190-
metadata={"issuer": issuer},
191-
logout_token=token,
192-
client_id=client_id,
193-
key_loader=self._key_loader(public_jwk),
167+
decode_logout_token(
168+
"test-issuer",
169+
self._key_loader(public_jwk),
170+
token,
194171
)
195172

196173
assert isinstance(excinfo.value.__cause__, MissingClaimError)

manage_breast_screening/auth/views.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from django.views.decorators.http import require_http_methods
1212

1313
from .oauth import get_cis2_client, jwk_from_public_key
14-
from .services import DecodeLogoutToken, InvalidLogoutToken
14+
from .services import InvalidLogoutToken, decode_logout_token
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -115,12 +115,7 @@ def cis2_back_channel_logout(request):
115115
metadata = client.load_server_metadata()
116116
key_loader = client.create_load_key()
117117
try:
118-
claims = DecodeLogoutToken().call(
119-
metadata=metadata,
120-
logout_token=logout_token,
121-
client_id=client.client_id,
122-
key_loader=key_loader,
123-
)
118+
claims = decode_logout_token(metadata["issuer"], key_loader, logout_token)
124119
except InvalidLogoutToken:
125120
return HttpResponseBadRequest("Invalid logout token")
126121

0 commit comments

Comments
 (0)