Skip to content

Commit 63bf224

Browse files
committed
Add test cases for TokenCache and SerializableTokenCache
1 parent 4a5255c commit 63bf224

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed

msal/token_cache.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def find(self, credential_type, target=None, query=None):
5050
if target else True)
5151
]
5252

53-
def add(self, event):
53+
def add(self, event, now=None):
5454
# type: (dict) -> None
5555
# event typically contains: client_id, scope, token_endpoint,
5656
# resposne, params, data, grant_type
@@ -86,7 +86,7 @@ def add(self, event):
8686
realm or "",
8787
target,
8888
]).lower()
89-
now = time.time()
89+
now = time.time() if now is None else now
9090
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
9191
"credential_type": self.CredentialType.ACCESS_TOKEN,
9292
"secret": access_token,
@@ -202,8 +202,8 @@ class SerializableTokenCache(TokenCache):
202202
Indicates whether the cache state has changed since last
203203
:func:`~serialize` or :func:`~deserialize` call.
204204
"""
205-
def add(self, event):
206-
super(SerializableTokenCache, self).add(event)
205+
def add(self, event, **kwargs):
206+
super(SerializableTokenCache, self).add(event, **kwargs)
207207
self.has_state_changed = True
208208

209209
def remove_rt(self, rt_item):

tests/test_token_cache.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import logging
2+
import base64
3+
import json
4+
5+
from msal.token_cache import *
6+
from tests import unittest
7+
8+
9+
logger = logging.getLogger(__name__)
10+
logging.basicConfig(level=logging.DEBUG)
11+
12+
13+
class TokenCacheTestCase(unittest.TestCase):
14+
15+
def setUp(self):
16+
self.cache = TokenCache()
17+
18+
def testAdd(self):
19+
client_info = base64.b64encode(b'''
20+
{"uid": "uid", "utid": "utid"}
21+
''').decode('utf-8')
22+
id_token = "header.%s.signature" % base64.b64encode(b'''{
23+
"sub": "subject",
24+
"oid": "object1234",
25+
"preferred_username": "John Doe"
26+
}''').decode('utf-8')
27+
self.cache.add({
28+
"client_id": "my_client_id",
29+
"scope": ["s2", "s1", "s3"], # Not in particular order
30+
"token_endpoint": "https://login.example.com/contoso/v2/token",
31+
"response": {
32+
"access_token": "an access token",
33+
"token_type": "some type",
34+
"expires_in": 3600,
35+
"refresh_token": "a refresh token",
36+
"client_info": client_info,
37+
"id_token": id_token,
38+
},
39+
}, now=1000)
40+
self.assertEqual(
41+
{
42+
'cached_at': 1000,
43+
'client_id': 'my_client_id',
44+
'credential_type': 'AccessToken',
45+
'environment': 'login.example.com',
46+
'expires_on': 4600,
47+
'extended_expires_on': 1000,
48+
'home_account_id': "uid.utid",
49+
'realm': 'contoso',
50+
'secret': 'an access token',
51+
'target': 's2 s1 s3',
52+
},
53+
self.cache._cache["AccessToken"].get(
54+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3')
55+
)
56+
self.assertEqual(
57+
{
58+
'client_id': 'my_client_id',
59+
'credential_type': 'RefreshToken',
60+
'environment': 'login.example.com',
61+
'home_account_id': "uid.utid",
62+
'secret': 'a refresh token',
63+
'target': 's2 s1 s3',
64+
},
65+
self.cache._cache["RefreshToken"].get(
66+
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
67+
)
68+
self.assertEqual(
69+
{
70+
'home_account_id': "uid.utid",
71+
'environment': 'login.example.com',
72+
'realm': 'contoso',
73+
'local_account_id': "object1234",
74+
'username': "John Doe",
75+
'authority_type': "AAD",
76+
},
77+
self.cache._cache["Account"].get('uid.utid-login.example.com-contoso')
78+
)
79+
self.assertEqual(
80+
{
81+
'credential_type': 'IdToken',
82+
'secret': id_token,
83+
'home_account_id': "uid.utid",
84+
'environment': 'login.example.com',
85+
'realm': 'contoso',
86+
'client_id': 'my_client_id',
87+
},
88+
self.cache._cache["IdToken"].get(
89+
'uid.utid-login.example.com-idtoken-my_client_id-contoso')
90+
)
91+
92+
93+
class SerializableTokenCacheTestCase(TokenCacheTestCase):
94+
# Run all inherited test methods, and have extra check in tearDown()
95+
96+
def setUp(self):
97+
self.cache = SerializableTokenCache()
98+
self.cache.deserialize("""
99+
{
100+
"AccessToken": {
101+
"an-entry": {
102+
"foo": "bar"
103+
}
104+
},
105+
"customized": "whatever"
106+
}
107+
""")
108+
109+
def tearDown(self):
110+
state = self.cache.serialize()
111+
logger.debug("serialize() = %s", state)
112+
# Now assert all extended content are kept intact
113+
output = json.loads(state)
114+
self.assertEqual(output.get("customized"), "whatever",
115+
"Undefined cache keys and their values should be intact")
116+
self.assertEqual(
117+
output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"},
118+
"Undefined token keys and their values should be intact")
119+

0 commit comments

Comments
 (0)