Skip to content

Commit 1ddddf4

Browse files
committed
Device Flow
1 parent 35a7726 commit 1ddddf4

File tree

2 files changed

+87
-15
lines changed

2 files changed

+87
-15
lines changed

msal/application.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
import time
2+
try: # Python 2
3+
from urlparse import urljoin
4+
except: # Python 3
5+
from urllib.parse import urljoin
6+
import logging
27

38
from oauth2cli import Client
49
from .authority import Authority
@@ -61,7 +66,11 @@ def __init__(
6166
default_body["client_info"] = 1
6267
self.client = Client(
6368
self.client_id,
64-
configuration={"token_endpoint": self.authority.token_endpoint},
69+
configuration={
70+
"token_endpoint": self.authority.token_endpoint,
71+
"device_authorization_endpoint": urljoin(
72+
self.authority.token_endpoint, "devicecode"),
73+
},
6574
default_body=default_body,
6675
on_obtaining_tokens=self.token_cache.add,
6776
on_removing_rt=self.token_cache.remove_rt,
@@ -219,7 +228,32 @@ def acquire_token_silent(
219228
scope=decorate_scope(scope, self.client_id))
220229
if "error" not in response:
221230
return response
222-
231+
logging.debug(
232+
"Refresh failed. {error}: {error_description}".format(**response))
233+
234+
def initiate_device_flow(self, scope=None, **kwargs):
235+
return self.client.initiate_device_flow(
236+
scope=decorate_scope(scope, self.client_id) if scope else None,
237+
**kwargs)
238+
239+
def acquire_token_by_device_flow(
240+
self, flow, exit_condition=lambda: True, **kwargs):
241+
"""Obtain token by a device flow object, with optional polling effect.
242+
243+
Args:
244+
flow (dict):
245+
An object previously generated by initiate_device_flow(...).
246+
exit_condition (Callable):
247+
This method implements a loop to provide polling effect.
248+
The loop's exit condition is calculated by this callback.
249+
The default callback makes the loop run only once, i.e. no polling.
250+
"""
251+
return self.client.obtain_token_by_device_flow(
252+
flow, exit_condition=exit_condition,
253+
data={"code": flow["device_code"]}, # 2018-10-4 Hack:
254+
# during transition period,
255+
# service seemingly need both device_code and code parameter.
256+
**kwargs)
223257

224258
class PublicClientApplication(ClientApplication): # browser app or mobile app
225259

@@ -254,8 +288,6 @@ def acquire_token(
254288
# It will handle the TWO round trips of Authorization Code Grant flow.
255289
raise NotImplemented()
256290

257-
# TODO: Support Device Code flow
258-
259291

260292
class ConfidentialClientApplication(ClientApplication): # server-side web app
261293

tests/test_application.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@
1313
with open(CONFIG_FILE) as conf:
1414
CONFIG = json.load(conf)
1515

16+
logger = logging.getLogger(__file__)
17+
logging.basicConfig(level=logging.DEBUG)
18+
19+
20+
class Oauth2TestCase(unittest.TestCase):
21+
22+
def assertLoosely(self, response, assertion=None,
23+
skippable_errors=("invalid_grant", "interaction_required")):
24+
if response.get("error") in skippable_errors:
25+
logger.debug("Response = %s", response)
26+
# Some of these errors are configuration issues, not library issues
27+
raise unittest.SkipTest(response.get("error_description"))
28+
else:
29+
if assertion is None:
30+
assertion = lambda: self.assertIn(
31+
"access_token", response,
32+
"{error}: {error_description}".format(
33+
# Do explicit response.get(...) rather than **response
34+
error=response.get("error"),
35+
error_description=response.get("error_description")))
36+
assertion()
37+
1638

1739
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
1840
class TestConfidentialClientApplication(unittest.TestCase):
@@ -58,23 +80,14 @@ def test_username_password(self):
5880

5981

6082
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
61-
class TestClientApplication(unittest.TestCase):
83+
class TestClientApplication(Oauth2TestCase):
6284

6385
@classmethod
6486
def setUpClass(cls):
6587
cls.app = ClientApplication(
6688
CONFIG["client_id"], client_credential=CONFIG.get("client_secret"),
6789
authority=CONFIG.get("authority"))
6890

69-
def assertLoosely(self, result):
70-
if "error" in result:
71-
# Some of these errors are configuration issues, not library issues
72-
if result["error"] == "invalid_grant":
73-
raise unittest.SkipTest(result.get("error_description"))
74-
self.assertEqual(result["error"], "interaction_required")
75-
else:
76-
self.assertIn('access_token', result)
77-
7891
@unittest.skipUnless("scope" in CONFIG, "Missing scope")
7992
def test_auth_code(self):
8093
from oauth2cli.authcode import obtain_auth_code
@@ -88,8 +101,18 @@ def test_auth_code(self):
88101
result = self.app.acquire_token_with_authorization_code(
89102
ac, CONFIG["scope"], redirect_uri=redirect_uri)
90103
logging.debug("cache = %s", json.dumps(self.app.token_cache._cache, indent=4))
91-
self.assertIn("access_token", result, "We should receive AT by auth code")
104+
self.assertIn(
105+
"access_token", result,
106+
"{error}: {error_description}".format(
107+
# Note: No interpolation here, cause error won't always present
108+
error=result.get("error"),
109+
error_description=result.get("error_description")))
110+
111+
self.assertCacheWorks(result)
92112

113+
114+
def assertCacheWorks(self, result_from_wire):
115+
result = result_from_wire
93116
# Going to test acquire_token_silent(...) to locate an AT from cache
94117
# In practice, you may want to filter based on its "username" field
95118
accounts = self.app.get_accounts()
@@ -109,3 +132,20 @@ def test_auth_code(self):
109132
self.assertNotEqual(result['access_token'], result_from_cache['access_token'],
110133
"We should get a fresh AT (via RT)")
111134

135+
def test_device_flow(self):
136+
flow = self.app.initiate_device_flow(scope=CONFIG.get("scope"))
137+
logging.warn(flow["message"])
138+
139+
duration = 30
140+
logging.warn("We will wait up to %d seconds for you to sign in" % duration)
141+
result = self.app.acquire_token_by_device_flow(
142+
flow,
143+
exit_condition=lambda end=time.time() + duration: time.time() > end)
144+
self.assertLoosely(
145+
result,
146+
assertion=lambda: self.assertIn('access_token', result),
147+
skippable_errors=self.app.client.DEVICE_FLOW_RETRIABLE_ERRORS)
148+
149+
if "access_token" in result:
150+
self.assertCacheWorks(result)
151+

0 commit comments

Comments
 (0)