33import json
44import time
55
6- from msal .token_cache import *
6+ from msal .token_cache import TokenCache , SerializableTokenCache
77from tests import unittest
88
99
@@ -51,11 +51,14 @@ class TokenCacheTestCase(unittest.TestCase):
5151
5252 def setUp (self ):
5353 self .cache = TokenCache ()
54+ self .at_key_maker = self .cache .key_makers [
55+ TokenCache .CredentialType .ACCESS_TOKEN ]
5456
5557 def testAddByAad (self ):
5658 client_id = "my_client_id"
5759 id_token = build_id_token (
5860 oid = "object1234" , preferred_username = "John Doe" , aud = client_id )
61+ now = 1000
5962 self .cache .add ({
6063 "client_id" : client_id ,
6164 "scope" : ["s2" , "s1" , "s3" ], # Not in particular order
@@ -64,7 +67,7 @@ def testAddByAad(self):
6467 uid = "uid" , utid = "utid" , # client_info
6568 expires_in = 3600 , access_token = "an access token" ,
6669 id_token = id_token , refresh_token = "a refresh token" ),
67- }, now = 1000 )
70+ }, now = now )
6871 access_token_entry = {
6972 'cached_at' : "1000" ,
7073 'client_id' : 'my_client_id' ,
@@ -78,14 +81,11 @@ def testAddByAad(self):
7881 'target' : 's1 s2 s3' , # Sorted
7982 'token_type' : 'some type' ,
8083 }
81- self .assertEqual (
82- access_token_entry ,
83- self .cache ._cache ["AccessToken" ].get (
84- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' )
85- )
84+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
85+ self .at_key_maker (** access_token_entry )))
8686 self .assertIn (
8787 access_token_entry ,
88- self .cache .find (self .cache .CredentialType .ACCESS_TOKEN ),
88+ self .cache .find (self .cache .CredentialType .ACCESS_TOKEN , now = now ),
8989 "find(..., query=None) should not crash, even though MSAL does not use it" )
9090 self .assertEqual (
9191 {
@@ -144,8 +144,7 @@ def testAddByAdfs(self):
144144 expires_in = 3600 , access_token = "an access token" ,
145145 id_token = id_token , refresh_token = "a refresh token" ),
146146 }, now = 1000 )
147- self .assertEqual (
148- {
147+ access_token_entry = {
149148 'cached_at' : "1000" ,
150149 'client_id' : 'my_client_id' ,
151150 'credential_type' : 'AccessToken' ,
@@ -157,10 +156,9 @@ def testAddByAdfs(self):
157156 'secret' : 'an access token' ,
158157 'target' : 's1 s2 s3' , # Sorted
159158 'token_type' : 'some type' ,
160- },
161- self .cache ._cache ["AccessToken" ].get (
162- 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3' )
163- )
159+ }
160+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
161+ self .at_key_maker (** access_token_entry )))
164162 self .assertEqual (
165163 {
166164 'client_id' : 'my_client_id' ,
@@ -206,37 +204,67 @@ def testAddByAdfs(self):
206204 "appmetadata-fs.msidlab8.com-my_client_id" )
207205 )
208206
209- def test_key_id_is_also_recorded (self ):
210- my_key_id = "some_key_id_123"
207+ def assertFoundAccessToken (self , * , scopes , query , data = None , now = None ):
208+ cached_at = None
209+ for cached_at in self .cache .search (
210+ TokenCache .CredentialType .ACCESS_TOKEN ,
211+ target = scopes , query = query , now = now ,
212+ ):
213+ for k , v in (data or {}).items (): # The extra data, if any
214+ self .assertEqual (cached_at .get (k ), v , f"AT should contain { k } ={ v } " )
215+ self .assertTrue (cached_at , "AT should be cached and searchable" )
216+ return cached_at
217+
218+ def _test_data_should_be_saved_and_searchable_in_access_token (self , data ):
219+ scopes = ["s2" , "s1" , "s3" ] # Not in particular order
220+ now = 1000
211221 self .cache .add ({
212- "data" : { "key_id" : my_key_id } ,
222+ "data" : data ,
213223 "client_id" : "my_client_id" ,
214- "scope" : [ "s2" , "s1" , "s3" ], # Not in particular order
224+ "scope" : scopes ,
215225 "token_endpoint" : "https://login.example.com/contoso/v2/token" ,
216226 "response" : build_response (
217227 uid = "uid" , utid = "utid" , # client_info
218228 expires_in = 3600 , access_token = "an access token" ,
219229 refresh_token = "a refresh token" ),
220- }, now = 1000 )
221- cached_key_id = self .cache ._cache ["AccessToken" ].get (
222- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
223- {}).get ("key_id" )
224- self .assertEqual (my_key_id , cached_key_id , "AT should be bound to the key" )
230+ }, now = now )
231+ self .assertFoundAccessToken (scopes = scopes , data = data , now = now , query = dict (
232+ data , # Also use the extra data as a query criteria
233+ client_id = "my_client_id" ,
234+ environment = "login.example.com" ,
235+ realm = "contoso" ,
236+ home_account_id = "uid.utid" ,
237+ ))
238+
239+ def test_extra_data_should_also_be_recorded_and_searchable_in_access_token (self ):
240+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
241+
242+ def test_access_tokens_with_different_key_id (self ):
243+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
244+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "2" })
245+ self .assertEqual (
246+ len (self .cache ._cache ["AccessToken" ]),
247+ 1 , """Historically, tokens are not keyed by key_id,
248+ so a new token overwrites the old one, and we would end up with 1 token in cache""" )
225249
226250 def test_refresh_in_should_be_recorded_as_refresh_on (self ): # Sounds weird. Yep.
251+ scopes = ["s2" , "s1" , "s3" ] # Not in particular order
227252 self .cache .add ({
228253 "client_id" : "my_client_id" ,
229- "scope" : [ "s2" , "s1" , "s3" ], # Not in particular order
254+ "scope" : scopes ,
230255 "token_endpoint" : "https://login.example.com/contoso/v2/token" ,
231256 "response" : build_response (
232257 uid = "uid" , utid = "utid" , # client_info
233258 expires_in = 3600 , refresh_in = 1800 , access_token = "an access token" ,
234259 ), #refresh_token="a refresh token"),
235260 }, now = 1000 )
236- refresh_on = self .cache ._cache ["AccessToken" ].get (
237- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
238- {}).get ("refresh_on" )
239- self .assertEqual ("2800" , refresh_on , "Should save refresh_on" )
261+ at = self .assertFoundAccessToken (scopes = scopes , query = dict (
262+ client_id = "my_client_id" ,
263+ environment = "login.example.com" ,
264+ realm = "contoso" ,
265+ home_account_id = "uid.utid" ,
266+ ))
267+ self .assertEqual ("2800" , at .get ("refresh_on" ), "Should save refresh_on" )
240268
241269 def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt (self ):
242270 sample = {
@@ -258,7 +286,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
258286 )
259287
260288
261- class SerializableTokenCacheTestCase (TokenCacheTestCase ):
289+ class SerializableTokenCacheTestCase (unittest . TestCase ):
262290 # Run all inherited test methods, and have extra check in tearDown()
263291
264292 def setUp (self ):
0 commit comments