33import json
44import time
55
6- from msal .token_cache import *
6+ from msal .token_cache import TokenCache , SerializableTokenCache
77from tests import unittest
88
99
@@ -51,6 +51,8 @@ class TokenCacheTestCase(unittest.TestCase):
5151
5252 def setUp (self ):
5353 self .cache = TokenCache ()
54+ self .at_key_maker = self .cache .key_makers [
55+ TokenCache .CredentialType .ACCESS_TOKEN ]
5456
5557 def testAddByAad (self ):
5658 client_id = "my_client_id"
@@ -78,11 +80,8 @@ def testAddByAad(self):
7880 'target' : 's1 s2 s3' , # Sorted
7981 'token_type' : 'some type' ,
8082 }
81- self .assertEqual (
82- access_token_entry ,
83- self .cache ._cache ["AccessToken" ].get (
84- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' )
85- )
83+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
84+ self .at_key_maker (** access_token_entry )))
8685 self .assertIn (
8786 access_token_entry ,
8887 self .cache .find (self .cache .CredentialType .ACCESS_TOKEN ),
@@ -144,8 +143,7 @@ def testAddByAdfs(self):
144143 expires_in = 3600 , access_token = "an access token" ,
145144 id_token = id_token , refresh_token = "a refresh token" ),
146145 }, now = 1000 )
147- self .assertEqual (
148- {
146+ access_token_entry = {
149147 'cached_at' : "1000" ,
150148 'client_id' : 'my_client_id' ,
151149 'credential_type' : 'AccessToken' ,
@@ -157,10 +155,9 @@ def testAddByAdfs(self):
157155 'secret' : 'an access token' ,
158156 'target' : 's1 s2 s3' , # Sorted
159157 'token_type' : 'some type' ,
160- },
161- self .cache ._cache ["AccessToken" ].get (
162- 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3' )
163- )
158+ }
159+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
160+ self .at_key_maker (** access_token_entry )))
164161 self .assertEqual (
165162 {
166163 'client_id' : 'my_client_id' ,
@@ -238,37 +235,32 @@ def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
238235 def test_extra_data_should_also_be_recorded_and_searchable_in_access_token (self ):
239236 self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
240237
241- def test_key_id_is_also_recorded (self ):
242- my_key_id = "some_key_id_123"
243- self .cache .add ({
244- "data" : {"key_id" : my_key_id },
245- "client_id" : "my_client_id" ,
246- "scope" : ["s2" , "s1" , "s3" ], # Not in particular order
247- "token_endpoint" : "https://login.example.com/contoso/v2/token" ,
248- "response" : build_response (
249- uid = "uid" , utid = "utid" , # client_info
250- expires_in = 3600 , access_token = "an access token" ,
251- refresh_token = "a refresh token" ),
252- }, now = 1000 )
253- cached_key_id = self .cache ._cache ["AccessToken" ].get (
254- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
255- {}).get ("key_id" )
256- self .assertEqual (my_key_id , cached_key_id , "AT should be bound to the key" )
238+ def test_access_tokens_with_different_key_id (self ):
239+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
240+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "2" })
241+ self .assertEqual (
242+ len (self .cache ._cache ["AccessToken" ]),
243+ 1 , """Historically, tokens are not keyed by key_id,
244+ so a new token overwrites the old one, and we would end up with 1 token in cache""" )
257245
258246 def test_refresh_in_should_be_recorded_as_refresh_on (self ): # Sounds weird. Yep.
247+ scopes = ["s2" , "s1" , "s3" ] # Not in particular order
259248 self .cache .add ({
260249 "client_id" : "my_client_id" ,
261- "scope" : [ "s2" , "s1" , "s3" ], # Not in particular order
250+ "scope" : scopes ,
262251 "token_endpoint" : "https://login.example.com/contoso/v2/token" ,
263252 "response" : build_response (
264253 uid = "uid" , utid = "utid" , # client_info
265254 expires_in = 3600 , refresh_in = 1800 , access_token = "an access token" ,
266255 ), #refresh_token="a refresh token"),
267256 }, now = 1000 )
268- refresh_on = self .cache ._cache ["AccessToken" ].get (
269- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
270- {}).get ("refresh_on" )
271- self .assertEqual ("2800" , refresh_on , "Should save refresh_on" )
257+ at = self .assertFoundAccessToken (scopes = scopes , query = dict (
258+ client_id = "my_client_id" ,
259+ environment = "login.example.com" ,
260+ realm = "contoso" ,
261+ home_account_id = "uid.utid" ,
262+ ))
263+ self .assertEqual ("2800" , at .get ("refresh_on" ), "Should save refresh_on" )
272264
273265 def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt (self ):
274266 sample = {
0 commit comments