Skip to content

Commit 0108c56

Browse files
committed
poetry jwt
1 parent d39797f commit 0108c56

File tree

2 files changed

+193
-3
lines changed

2 files changed

+193
-3
lines changed

lambdas/shared/poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import base64
2+
import json
3+
import responses
4+
import time
5+
import unittest
6+
from responses import matchers
7+
from unittest.mock import MagicMock, patch, ANY
8+
9+
from common.authentication import AppRestrictedAuth, Service
10+
from common.models.errors import UnhandledResponseError
11+
12+
13+
class TestAuthenticator(unittest.TestCase):
14+
def setUp(self):
15+
self.kid = "a_kid"
16+
self.api_key = "an_api_key"
17+
self.private_key = "a_private_key"
18+
# The private key must be stored as base64 encoded in secret-manager
19+
b64_private_key = base64.b64encode(self.private_key.encode()).decode()
20+
21+
pds_secret = {"private_key_b64": b64_private_key, "kid": self.kid, "api_key": self.api_key}
22+
secret_response = {"SecretString": json.dumps(pds_secret)}
23+
24+
self.secret_manager_client = MagicMock()
25+
self.secret_manager_client.get_secret_value.return_value = secret_response
26+
27+
self.cache = MagicMock()
28+
self.cache.get.return_value = None
29+
30+
env = "an-env"
31+
self.authenticator = AppRestrictedAuth(Service.PDS, self.secret_manager_client, env, self.cache)
32+
self.url = f"https://{env}.api.service.nhs.uk/oauth2/token"
33+
34+
@responses.activate
35+
def test_post_request_to_token(self):
36+
"""it should send a POST request to oauth2 service"""
37+
_jwt = "a-jwt"
38+
request_data = {
39+
'grant_type': 'client_credentials',
40+
'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
41+
'client_assertion': _jwt
42+
}
43+
access_token = "an-access-token"
44+
responses.add(responses.POST, self.url, status=200, json={"access_token": access_token},
45+
match=[matchers.urlencoded_params_matcher(request_data)])
46+
47+
with patch("common.authentication.jwt.encode") as mock_jwt:
48+
mock_jwt.return_value = _jwt
49+
# When
50+
act_access_token = self.authenticator.get_access_token()
51+
52+
# Then
53+
self.assertEqual(act_access_token, access_token)
54+
55+
@responses.activate
56+
def test_jwt_values(self):
57+
"""it should send correct claims and header"""
58+
claims = {
59+
"iss": self.api_key,
60+
"sub": self.api_key,
61+
"aud": self.url,
62+
"iat": ANY,
63+
"exp": ANY,
64+
"jti": ANY
65+
}
66+
_jwt = "a-jwt"
67+
access_token = "an-access-token"
68+
69+
responses.add(responses.POST, self.url, status=200, json={"access_token": access_token})
70+
71+
with patch("jwt.encode") as mock_jwt:
72+
mock_jwt.return_value = _jwt
73+
# When
74+
self.authenticator.get_access_token()
75+
# Then
76+
mock_jwt.assert_called_once_with(claims, self.private_key,
77+
algorithm="RS512", headers={"kid": self.kid})
78+
79+
def test_env_mapping(self):
80+
"""it should target int environment for none-prod environment, otherwise int"""
81+
# For env=none-prod
82+
env = "some-env"
83+
auth = AppRestrictedAuth(Service.PDS, None, env, None)
84+
self.assertTrue(auth.token_url.startswith(f"https://{env}."))
85+
86+
# For env=prod
87+
env = "prod"
88+
auth = AppRestrictedAuth(Service.PDS, None, env, None)
89+
self.assertTrue(env not in auth.token_url)
90+
91+
def test_returned_cached_token(self):
92+
"""it should return cached token"""
93+
cached_token = {
94+
"token": "a-cached-access-token",
95+
"expires_at": int(time.time()) + 99999 # make sure it's not expired
96+
}
97+
self.cache.get.return_value = cached_token
98+
99+
# When
100+
token = self.authenticator.get_access_token()
101+
102+
# Then
103+
self.assertEqual(token, cached_token["token"])
104+
self.secret_manager_client.assert_not_called()
105+
106+
@responses.activate
107+
def test_update_cache(self):
108+
"""it should update cached token"""
109+
self.cache.get.return_value = None
110+
token = "a-new-access-token"
111+
cached_token = {
112+
"token": token,
113+
"expires_at": ANY
114+
}
115+
responses.add(responses.POST, self.url, status=200, json={"access_token": token})
116+
117+
with patch("jwt.encode") as mock_jwt:
118+
mock_jwt.return_value = "a-jwt"
119+
# When
120+
self.authenticator.get_access_token()
121+
122+
# Then
123+
self.cache.put.assert_called_once_with(f"{Service.PDS.value}_access_token", cached_token)
124+
125+
@responses.activate
126+
def test_expired_token_in_cache(self):
127+
"""it should not return cached access token if it's expired"""
128+
now_epoch = 12345
129+
expires_at = now_epoch + self.authenticator.expiry
130+
cached_token = {
131+
"token": "an-expired-cached-access-token",
132+
"expires_at": expires_at,
133+
}
134+
self.cache.get.return_value = cached_token
135+
136+
new_token = "a-new-token"
137+
responses.add(responses.POST, self.url, status=200, json={"access_token": new_token})
138+
139+
new_now = expires_at # this is to trigger expiry and also the mocked now-time when storing the new token
140+
with patch("common.authentication.jwt.encode") as mock_jwt:
141+
with patch("time.time") as mock_time:
142+
mock_time.return_value = new_now
143+
mock_jwt.return_value = "a-jwt"
144+
# When
145+
self.authenticator.get_access_token()
146+
147+
# Then
148+
exp_cached_token = {
149+
"token": new_token,
150+
"expires_at": new_now + self.authenticator.expiry
151+
}
152+
self.cache.put.assert_called_once_with(ANY, exp_cached_token)
153+
154+
@responses.activate
155+
def test_uses_cache_for_token(self):
156+
"""it should use the cache for the `Service` auth call"""
157+
158+
token = "a-new-access-token"
159+
token_call = responses.add(responses.POST, self.url, status=200, json={"access_token": token})
160+
values = {}
161+
162+
def get_side_effect(key):
163+
return values.get(key, None)
164+
165+
def put_side_effect(key, value):
166+
values[key] = value
167+
168+
self.cache.get.side_effect = get_side_effect
169+
self.cache.put.side_effect = put_side_effect
170+
171+
with patch("common.authentication.jwt.encode") as mock_jwt:
172+
mock_jwt.return_value = "a-jwt"
173+
# When
174+
self.assertEqual(0, token_call.call_count)
175+
self.authenticator.get_access_token()
176+
self.assertEqual(1, token_call.call_count)
177+
self.authenticator.get_access_token()
178+
self.assertEqual(1, token_call.call_count)
179+
180+
@responses.activate
181+
def test_raise_exception(self):
182+
"""it should raise exception if auth response is not 200"""
183+
self.cache.get.return_value = None
184+
responses.add(responses.POST, self.url, status=400)
185+
186+
with patch("common.authentication.jwt.encode") as mock_jwt:
187+
mock_jwt.return_value = "a-jwt"
188+
with self.assertRaises(UnhandledResponseError):
189+
# When
190+
self.authenticator.get_access_token()

0 commit comments

Comments
 (0)