1313 with open (CONFIG_FILE ) as conf :
1414 CONFIG = json .load (conf )
1515
16+ logger = logging .getLogger (__file__ )
17+ logging .basicConfig (level = logging .DEBUG )
18+
19+
20+ class Oauth2TestCase (unittest .TestCase ):
21+
22+ def assertLoosely (self , response , assertion = None ,
23+ skippable_errors = ("invalid_grant" , "interaction_required" )):
24+ if response .get ("error" ) in skippable_errors :
25+ logger .debug ("Response = %s" , response )
26+ # Some of these errors are configuration issues, not library issues
27+ raise unittest .SkipTest (response .get ("error_description" ))
28+ else :
29+ if assertion is None :
30+ assertion = lambda : self .assertIn (
31+ "access_token" , response ,
32+ "{error}: {error_description}" .format (
33+ # Do explicit response.get(...) rather than **response
34+ error = response .get ("error" ),
35+ error_description = response .get ("error_description" )))
36+ assertion ()
37+
1638
1739@unittest .skipUnless ("client_id" in CONFIG , "client_id missing" )
1840class TestConfidentialClientApplication (unittest .TestCase ):
@@ -58,23 +80,14 @@ def test_username_password(self):
5880
5981
6082@unittest .skipUnless ("client_id" in CONFIG , "client_id missing" )
61- class TestClientApplication (unittest . TestCase ):
83+ class TestClientApplication (Oauth2TestCase ):
6284
6385 @classmethod
6486 def setUpClass (cls ):
6587 cls .app = ClientApplication (
6688 CONFIG ["client_id" ], client_credential = CONFIG .get ("client_secret" ),
6789 authority = CONFIG .get ("authority" ))
6890
69- def assertLoosely (self , result ):
70- if "error" in result :
71- # Some of these errors are configuration issues, not library issues
72- if result ["error" ] == "invalid_grant" :
73- raise unittest .SkipTest (result .get ("error_description" ))
74- self .assertEqual (result ["error" ], "interaction_required" )
75- else :
76- self .assertIn ('access_token' , result )
77-
7891 @unittest .skipUnless ("scope" in CONFIG , "Missing scope" )
7992 def test_auth_code (self ):
8093 from oauth2cli .authcode import obtain_auth_code
@@ -88,8 +101,18 @@ def test_auth_code(self):
88101 result = self .app .acquire_token_with_authorization_code (
89102 ac , CONFIG ["scope" ], redirect_uri = redirect_uri )
90103 logging .debug ("cache = %s" , json .dumps (self .app .token_cache ._cache , indent = 4 ))
91- self .assertIn ("access_token" , result , "We should receive AT by auth code" )
104+ self .assertIn (
105+ "access_token" , result ,
106+ "{error}: {error_description}" .format (
107+ # Note: No interpolation here, cause error won't always present
108+ error = result .get ("error" ),
109+ error_description = result .get ("error_description" )))
110+
111+ self .assertCacheWorks (result )
92112
113+
114+ def assertCacheWorks (self , result_from_wire ):
115+ result = result_from_wire
93116 # Going to test acquire_token_silent(...) to locate an AT from cache
94117 # In practice, you may want to filter based on its "username" field
95118 accounts = self .app .get_accounts ()
@@ -109,3 +132,20 @@ def test_auth_code(self):
109132 self .assertNotEqual (result ['access_token' ], result_from_cache ['access_token' ],
110133 "We should get a fresh AT (via RT)" )
111134
135+ def test_device_flow (self ):
136+ flow = self .app .initiate_device_flow (scope = CONFIG .get ("scope" ))
137+ logging .warn (flow ["message" ])
138+
139+ duration = 30
140+ logging .warn ("We will wait up to %d seconds for you to sign in" % duration )
141+ result = self .app .acquire_token_by_device_flow (
142+ flow ,
143+ exit_condition = lambda end = time .time () + duration : time .time () > end )
144+ self .assertLoosely (
145+ result ,
146+ assertion = lambda : self .assertIn ('access_token' , result ),
147+ skippable_errors = self .app .client .DEVICE_FLOW_RETRIABLE_ERRORS )
148+
149+ if "access_token" in result :
150+ self .assertCacheWorks (result )
151+
0 commit comments