2626 SERVICE_FABRIC ,
2727 DEFAULT_TO_VM ,
2828)
29+ from msal .token_cache import is_subdict_of
2930
3031
3132class ManagedIdentityTestCase (unittest .TestCase ):
@@ -60,7 +61,7 @@ def setUp(self):
6061 http_client = requests .Session (),
6162 )
6263
63- def _test_token_cache (self , app ):
64+ def assertCacheStatus (self , app ):
6465 cache = app ._token_cache ._cache
6566 self .assertEqual (1 , len (cache .get ("AccessToken" , [])), "Should have 1 AT" )
6667 at = list (cache ["AccessToken" ].values ())[0 ]
@@ -70,30 +71,55 @@ def _test_token_cache(self, app):
7071 "Should have expected client_id" )
7172 self .assertEqual ("managed_identity" , at ["realm" ], "Should have expected realm" )
7273
73- def _test_happy_path (self , app , mocked_http ):
74- result = app .acquire_token_for_client (resource = "R" )
74+ def _test_happy_path (self , app , mocked_http , expires_in , resource = "R" ):
75+ result = app .acquire_token_for_client (resource = resource )
7576 mocked_http .assert_called ()
76- self .assertEqual ({
77+ call_count = mocked_http .call_count
78+ expected_result = {
7779 "access_token" : "AT" ,
78- "expires_in" : 1234 ,
79- "resource" : "R" ,
8080 "token_type" : "Bearer" ,
81- }, result , "Should obtain a token response" )
81+ }
82+ self .assertTrue (
83+ is_subdict_of (expected_result , result ), # We will test refresh_on later
84+ "Should obtain a token response" )
85+ self .assertEqual (expires_in , result ["expires_in" ], "Should have expected expires_in" )
86+ if expires_in >= 7200 :
87+ expected_refresh_on = int (time .time () + expires_in / 2 )
88+ self .assertTrue (
89+ expected_refresh_on - 1 <= result ["refresh_on" ] <= expected_refresh_on + 1 ,
90+ "Should have a refresh_on time around the middle of the token's life" )
8291 self .assertEqual (
8392 result ["access_token" ],
84- app .acquire_token_for_client (resource = "R" ).get ("access_token" ),
93+ app .acquire_token_for_client (resource = resource ).get ("access_token" ),
8594 "Should hit the same token from cache" )
86- self ._test_token_cache (app )
95+
96+ self .assertCacheStatus (app )
97+
98+ result = app .acquire_token_for_client (resource = resource )
99+ self .assertEqual (
100+ call_count , mocked_http .call_count ,
101+ "No new call to the mocked http should be made for a cache hit" )
102+ self .assertTrue (
103+ is_subdict_of (expected_result , result ), # We will test refresh_on later
104+ "Should obtain a token response" )
105+ self .assertTrue (
106+ expires_in - 5 < result ["expires_in" ] <= expires_in ,
107+ "Should have similar expires_in" )
108+ if expires_in >= 7200 :
109+ self .assertTrue (
110+ expected_refresh_on - 5 < result ["refresh_on" ] <= expected_refresh_on ,
111+ "Should have a refresh_on time around the middle of the token's life" )
87112
88113
89114class VmTestCase (ClientTestCase ):
90115
91116 def test_happy_path (self ):
117+ expires_in = 7890 # We test a bigger than 7200 value here
92118 with patch .object (self .app ._http_client , "get" , return_value = MinimalResponse (
93119 status_code = 200 ,
94- text = '{"access_token": "AT", "expires_in": "1234 ", "resource": "R"}' ,
120+ text = '{"access_token": "AT", "expires_in": "%s ", "resource": "R"}' % expires_in ,
95121 )) as mocked_method :
96- self ._test_happy_path (self .app , mocked_method )
122+ self ._test_happy_path (self .app , mocked_method , expires_in )
97123
98124 def test_vm_error_should_be_returned_as_is (self ):
99125 raw_error = '{"raw": "error format is undefined"}'
@@ -110,12 +136,13 @@ def test_vm_error_should_be_returned_as_is(self):
110136class AppServiceTestCase (ClientTestCase ):
111137
112138 def test_happy_path (self ):
139+ expires_in = 1234
113140 with patch .object (self .app ._http_client , "get" , return_value = MinimalResponse (
114141 status_code = 200 ,
115142 text = '{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
116- int (time .time ()) + 1234 ),
143+ int (time .time ()) + expires_in ),
117144 )) as mocked_method :
118- self ._test_happy_path (self .app , mocked_method )
145+ self ._test_happy_path (self .app , mocked_method , expires_in )
119146
120147 def test_app_service_error_should_be_normalized (self ):
121148 raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
@@ -134,12 +161,13 @@ def test_app_service_error_should_be_normalized(self):
134161class MachineLearningTestCase (ClientTestCase ):
135162
136163 def test_happy_path (self ):
164+ expires_in = 1234
137165 with patch .object (self .app ._http_client , "get" , return_value = MinimalResponse (
138166 status_code = 200 ,
139167 text = '{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
140- int (time .time ()) + 1234 ),
168+ int (time .time ()) + expires_in ),
141169 )) as mocked_method :
142- self ._test_happy_path (self .app , mocked_method )
170+ self ._test_happy_path (self .app , mocked_method , expires_in )
143171
144172 def test_machine_learning_error_should_be_normalized (self ):
145173 raw_error = '{"error": "placeholder", "message": "placeholder"}'
@@ -162,12 +190,14 @@ def test_machine_learning_error_should_be_normalized(self):
162190class ServiceFabricTestCase (ClientTestCase ):
163191
164192 def _test_happy_path (self , app ):
193+ expires_in = 1234
165194 with patch .object (app ._http_client , "get" , return_value = MinimalResponse (
166195 status_code = 200 ,
167196 text = '{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
168- int (time .time ()) + 1234 ),
197+ int (time .time ()) + expires_in ),
169198 )) as mocked_method :
170- super (ServiceFabricTestCase , self )._test_happy_path (app , mocked_method )
199+ super (ServiceFabricTestCase , self )._test_happy_path (
200+ app , mocked_method , expires_in )
171201
172202 def test_happy_path (self ):
173203 self ._test_happy_path (self .app )
@@ -212,15 +242,16 @@ class ArcTestCase(ClientTestCase):
212242 })
213243
214244 def test_happy_path (self , mocked_stat ):
245+ expires_in = 1234
215246 with patch .object (self .app ._http_client , "get" , side_effect = [
216247 self .challenge ,
217248 MinimalResponse (
218249 status_code = 200 ,
219- text = '{"access_token": "AT", "expires_in": "1234 ", "resource": "R"}' ,
250+ text = '{"access_token": "AT", "expires_in": "%s ", "resource": "R"}' % expires_in ,
220251 ),
221252 ]) as mocked_method :
222253 try :
223- super ( ArcTestCase , self ) ._test_happy_path (self .app , mocked_method )
254+ self ._test_happy_path (self .app , mocked_method , expires_in )
224255 mocked_stat .assert_called_with (os .path .join (
225256 _supported_arc_platforms_and_their_prefixes [sys .platform ],
226257 "foo.key" ))
0 commit comments