Skip to content

Commit 54a6231

Browse files
committed
Merge branch 'cleanup-api-surface' into dev
2 parents ec803c7 + 5528fe3 commit 54a6231

File tree

2 files changed

+95
-106
lines changed

2 files changed

+95
-106
lines changed

msal/application.py

Lines changed: 81 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
logger = logging.getLogger(__name__)
2323

2424
def decorate_scope(
25-
scope, client_id,
26-
policy=None, # obsolete
25+
scopes, client_id,
2726
reserved_scope=frozenset(['openid', 'profile', 'offline_access'])):
28-
scope_set = set(scope) # Input scope is typically a list. Copy it to a set.
27+
if not isinstance(scopes, (list, set, tuple)):
28+
raise ValueError("The input scopes should be a list, tuple, or set")
29+
scope_set = set(scopes) # Input scopes is typically a list. Copy it to a set.
2930
if scope_set & reserved_scope:
3031
# These scopes are reserved for the API to provide good experience.
3132
# We could make the developer pass these and then if they do they will
@@ -53,7 +54,8 @@ class ClientApplication(object):
5354
def __init__(
5455
self, client_id,
5556
client_credential=None, authority=None, validate_authority=True,
56-
token_cache=None):
57+
token_cache=None,
58+
verify=True, proxies=None, timeout=None):
5759
"""
5860
:param client_credential: It can be a string containing client secret,
5961
or an X509 certificate container in this form:
@@ -70,6 +72,9 @@ def __init__(
7072
validate_authority)
7173
# Here the self.authority is not the same type as authority in input
7274
self.token_cache = token_cache or TokenCache()
75+
self.verify = verify
76+
self.proxies = proxies
77+
self.timeout = timeout
7378
self.client = self._build_client(client_credential, self.authority)
7479

7580
def _build_client(self, client_credential, authority):
@@ -104,13 +109,13 @@ def _build_client(self, client_credential, authority):
104109
on_obtaining_tokens=self.token_cache.add,
105110
on_removing_rt=self.token_cache.remove_rt,
106111
on_updating_rt=self.token_cache.update_rt,
107-
)
112+
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
108113

109114
def get_authorization_request_url(
110115
self,
111-
scope,
112-
additional_scope=frozenset([]), # Not yet supported
113-
login_hint=None,
116+
scopes, # type: list[str]
117+
# additional_scope=None, # type: Optional[list]
118+
login_hint=None, # type: Optional[str]
114119
state=None, # Recommended by OAuth2 for CSRF protection
115120
redirect_uri=None,
116121
authority=None, # By default, it will use self.authority;
@@ -119,15 +124,21 @@ def get_authorization_request_url(
119124
**kwargs):
120125
"""Constructs a URL for you to start a Authorization Code Grant.
121126
122-
:param scope: Scope refers to the resource that will be used in the
123-
resulting token's audience.
127+
:param scopes:
128+
Scopes requested to access a protected API (a resource).
129+
:param str state: Recommended by OAuth2 for CSRF protection.
130+
:param login_hint:
131+
Identifier of the user. Generally a User Principal Name (UPN).
132+
:param redirect_uri:
133+
Address to return to upon receiving a response from the authority.
134+
"""
135+
""" # TBD: this would only be meaningful in a new acquire_token_interactive()
124136
:param additional_scope: Additional scope is a concept only in AAD.
125137
It refers to other resources you might want to prompt to consent
126138
for in the same interaction, but for which you won't get back a
127139
token for in this particular operation.
128140
(Under the hood, we simply merge scope and additional_scope before
129141
sending them on the wire.)
130-
:param str state: Recommended by OAuth2 for CSRF protection.
131142
"""
132143
the_authority = Authority(authority) if authority else self.authority
133144
client = Client(
@@ -136,13 +147,13 @@ def get_authorization_request_url(
136147
return client.build_auth_request_uri(
137148
response_type="code", # Using Authorization Code grant
138149
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
139-
scope=decorate_scope(scope, self.client_id),
150+
scope=decorate_scope(scopes, self.client_id),
140151
)
141152

142-
def acquire_token_with_authorization_code(
153+
def acquire_token_by_authorization_code(
143154
self,
144155
code,
145-
scope, # Syntactically required. STS accepts empty value though.
156+
scopes, # Syntactically required. STS accepts empty value though.
146157
redirect_uri=None,
147158
# REQUIRED, if the "redirect_uri" parameter was included in the
148159
# authorization request as described in Section 4.1.1, and their
@@ -151,7 +162,7 @@ def acquire_token_with_authorization_code(
151162
"""The second half of the Authorization Code Grant.
152163
153164
:param code: The authorization code returned from Authorization Server.
154-
:param scope:
165+
:param scopes:
155166
156167
If you requested user consent for multiple resources, here you will
157168
typically want to provide a subset of what you required in AuthCode.
@@ -171,38 +182,49 @@ def acquire_token_with_authorization_code(
171182
# So in theory, you can omit scope here when you were working with only
172183
# one scope. But, MSAL decorates your scope anyway, so they are never
173184
# really empty.
174-
assert isinstance(scope, list), "Invalid parameter type"
175-
return self.client.obtain_token_with_authorization_code(
185+
assert isinstance(scopes, list), "Invalid parameter type"
186+
return self.client.obtain_token_by_authorization_code(
176187
code, redirect_uri=redirect_uri,
177-
data={"scope": decorate_scope(scope, self.client_id)},
188+
data={"scope": decorate_scope(scopes, self.client_id)},
178189
)
179190

180-
def get_accounts(self):
181-
"""Returns a list of account objects that can later be used to find token.
191+
def get_accounts(self, username=None):
192+
"""Get a list of accounts which previously signed in, i.e. exists in cache.
182193
183-
Each account object is a dict containing a "username" field (among others)
184-
which can use to determine which account to use.
194+
An account can later be used in acquire_token_silent() to find its tokens.
195+
Each account is a dict. For now, we only document its "username" field.
196+
Your app can choose to display those information to end user,
197+
and allow them to choose one of them to proceed.
198+
199+
:param username:
200+
Filter accounts with this username only. Case insensitive.
185201
"""
186202
# The following implementation finds accounts only from saved accounts,
187203
# but does NOT correlate them with saved RTs. It probably won't matter,
188204
# because in MSAL universe, there are always Accounts and RTs together.
189-
return self.token_cache.find(
190-
self.token_cache.CredentialType.ACCOUNT,
191-
query={"environment": self.authority.instance})
205+
accounts = self.token_cache.find(
206+
self.token_cache.CredentialType.ACCOUNT,
207+
query={"environment": self.authority.instance})
208+
if username:
209+
# Federated account["username"] from AAD could contain mixed case
210+
lowercase_username = username.lower()
211+
accounts = [a for a in accounts
212+
if a["username"].lower() == lowercase_username]
213+
return accounts
192214

193215
def acquire_token_silent(
194-
self, scope,
216+
self, scopes,
195217
account=None, # one of the account object returned by get_accounts()
196218
authority=None, # See get_authorization_request_url()
197219
force_refresh=False, # To force refresh an Access Token (not a RT)
198220
**kwargs):
199-
assert isinstance(scope, list), "Invalid parameter type"
221+
assert isinstance(scopes, list), "Invalid parameter type"
200222
the_authority = Authority(authority) if authority else self.authority
201223

202224
if force_refresh == False:
203225
matches = self.token_cache.find(
204226
self.token_cache.CredentialType.ACCESS_TOKEN,
205-
target=scope,
227+
target=scopes,
206228
query={
207229
"client_id": self.client_id,
208230
"environment": the_authority.instance,
@@ -221,7 +243,7 @@ def acquire_token_silent(
221243

222244
matches = self.token_cache.find(
223245
self.token_cache.CredentialType.REFRESH_TOKEN,
224-
# target=scope, # AAD RTs are scope-independent
246+
# target=scopes, # AAD RTs are scope-independent
225247
query={
226248
"client_id": self.client_id,
227249
"environment": the_authority.instance,
@@ -230,62 +252,51 @@ def acquire_token_silent(
230252
})
231253
client = self._build_client(self.client_credential, the_authority)
232254
for entry in matches:
233-
response = client.obtain_token_with_refresh_token(
255+
response = client.obtain_token_by_refresh_token(
234256
entry, rt_getter=lambda token_item: token_item["secret"],
235-
scope=decorate_scope(scope, self.client_id))
257+
scope=decorate_scope(scopes, self.client_id))
236258
if "error" not in response:
237259
return response
238-
logging.debug(
260+
logger.debug(
239261
"Refresh failed. {error}: {error_description}".format(**response))
240262

241-
def initiate_device_flow(self, scope=None, **kwargs):
263+
def initiate_device_flow(self, scopes=None, **kwargs):
242264
return self.client.initiate_device_flow(
243-
scope=decorate_scope(scope, self.client_id) if scope else None,
265+
scope=decorate_scope(scopes or [], self.client_id),
244266
**kwargs)
245267

246-
def acquire_token_by_device_flow(
247-
self, flow, exit_condition=lambda: True, **kwargs):
248-
"""Obtain token by a device flow object, with optional polling effect.
268+
def acquire_token_by_device_flow(self, flow, **kwargs):
269+
"""Obtain token by a device flow object, with customizable polling effect.
249270
250271
Args:
251272
flow (dict):
252-
An object previously generated by initiate_device_flow(...).
253-
exit_condition (Callable):
254-
This method implements a loop to provide polling effect.
255-
The loop's exit condition is calculated by this callback.
256-
The default callback makes the loop run only once, i.e. no polling.
273+
A dict previously generated by initiate_device_flow(...).
274+
You can exit the polling loop early, by changing the value of
275+
its "expires_at" key to 0, at any time.
257276
"""
258277
return self.client.obtain_token_by_device_flow(
259-
flow, exit_condition=exit_condition,
278+
flow,
260279
data={"code": flow["device_code"]}, # 2018-10-4 Hack:
261280
# during transition period,
262281
# service seemingly need both device_code and code parameter.
263282
**kwargs)
264283

265284
class PublicClientApplication(ClientApplication): # browser app or mobile app
266285

267-
## TBD: what if redirect_uri is not needed in the constructor at all?
268-
## Device Code flow does not need redirect_uri anyway.
269-
270-
# OUT_OF_BAND = "urn:ietf:wg:oauth:2.0:oob"
271-
# def __init__(self, client_id, redirect_uri=None, **kwargs):
272-
# super(PublicClientApplication, self).__init__(client_id, **kwargs)
273-
# self.redirect_uri = redirect_uri or self.OUT_OF_BAND
274-
275-
def acquire_token_with_username_password(
276-
self, username, password, scope=None, **kwargs):
286+
def acquire_token_by_username_password(
287+
self, username, password, scopes=None, **kwargs):
277288
"""Gets a token for a given resource via user credentails."""
278-
scope = decorate_scope(scope, self.client_id)
289+
scopes = decorate_scope(scopes, self.client_id)
279290
if not self.authority.is_adfs:
280291
user_realm_result = self.authority.user_realm_discovery(username)
281292
if user_realm_result.get("account_type") == "Federated":
282-
return self._acquire_token_with_username_password_federated(
283-
user_realm_result, username, password, scope=scope, **kwargs)
284-
return self.client.obtain_token_with_username_password(
285-
username, password, scope=scope, **kwargs)
293+
return self._acquire_token_by_username_password_federated(
294+
user_realm_result, username, password, scopes=scopes, **kwargs)
295+
return self.client.obtain_token_by_username_password(
296+
username, password, scope=scopes, **kwargs)
286297

287-
def _acquire_token_with_username_password_federated(
288-
self, user_realm_result, username, password, scope=None, **kwargs):
298+
def _acquire_token_by_username_password_federated(
299+
self, user_realm_result, username, password, scopes=None, **kwargs):
289300
wstrust_endpoint = {}
290301
if user_realm_result.get("federation_metadata_url"):
291302
wstrust_endpoint = mex.send_request(
@@ -306,42 +317,20 @@ def _acquire_token_with_username_password_federated(
306317
if not grant_type:
307318
raise RuntimeError(
308319
"RSTR returned unknown token type: %s", wstrust_result.get("type"))
309-
return self.client.obtain_token_with_assertion(
320+
return self.client.obtain_token_by_assertion(
310321
b64encode(wstrust_result["token"]),
311-
grant_type=grant_type, scope=scope, **kwargs)
312-
313-
def acquire_token(
314-
self,
315-
scope,
316-
# additional_scope=None, # See also get_authorization_request_url()
317-
login_hint=None,
318-
ui_options=None,
319-
# user=None, # TBD: It exists in MSAL-dotnet but not in MSAL-Android
320-
policy='',
321-
authority=None, # See get_authorization_request_url()
322-
extra_query_params=None,
323-
):
324-
# It will handle the TWO round trips of Authorization Code Grant flow.
325-
raise NotImplemented()
322+
grant_type=grant_type, scope=scopes, **kwargs)
326323

327324

328325
class ConfidentialClientApplication(ClientApplication): # server-side web app
329326

330-
def acquire_token_for_client(self, scope, force_refresh=False):
331-
"""Acquires token from the service for the confidential client.
327+
def acquire_token_for_client(self, scopes, **kwargs):
328+
"""Acquires token from the service for the confidential client."""
329+
# TBD: force_refresh behavior
330+
return self.client.obtain_token_for_client(
331+
scope=scopes, # This grant flow requires no scope decoration
332+
**kwargs)
332333

333-
:param force_refresh:
334-
This method attempts to look up valid access token in the cache.
335-
If this parameter is set to True,
336-
this method will ignore the access token in the cache
337-
and attempt to acquire new access token using client credentials
338-
"""
339-
# TODO: force_refresh will be implemented after the cache mechanism is ready
340-
return self.client.obtain_token_with_client_credentials(
341-
scope=scope, # This grant flow requires no scope decoration
342-
)
343-
344-
def acquire_token_on_behalf_of(
345-
self, user_assertion, scope, authority=None, policy=''):
346-
pass
334+
def acquire_token_on_behalf_of(self, user_assertion, scopes, authority=None):
335+
raise NotImplementedError()
347336

tests/test_application.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,21 @@ def assertLoosely(self, response, assertion=None,
3737

3838
def assertCacheWorks(self, result_from_wire):
3939
result = result_from_wire
40-
# Going to test acquire_token_silent(...) to locate an AT from cache
41-
# In practice, you may want to filter based on its "username" field
42-
accounts = self.app.get_accounts()
40+
# You can filter by predefined username, or let end user to choose one
41+
accounts = self.app.get_accounts(username=CONFIG.get("username"))
4342
self.assertNotEqual(0, len(accounts))
43+
account = accounts[0]
44+
# Going to test acquire_token_silent(...) to locate an AT from cache
4445
result_from_cache = self.app.acquire_token_silent(
45-
CONFIG["scope"], account=accounts[0])
46+
CONFIG["scope"], account=account)
4647
self.assertIsNotNone(result_from_cache)
4748
self.assertEqual(result['access_token'], result_from_cache['access_token'],
4849
"We should get a cached AT")
4950

5051
# Going to test acquire_token_silent(...) to obtain an AT by a RT from cache
5152
self.app.token_cache._cache["AccessToken"] = {} # A hacky way to clear ATs
5253
result_from_cache = self.app.acquire_token_silent(
53-
CONFIG["scope"], account=accounts[0])
54+
CONFIG["scope"], account=account)
5455
self.assertIsNotNone(result_from_cache,
5556
"We should get a result from acquire_token_silent(...) call")
5657
self.assertNotEqual(result['access_token'], result_from_cache['access_token'],
@@ -66,7 +67,7 @@ def assertCacheWorks(self, result_from_wire, result_from_cache):
6667
result_from_wire['access_token'], result_from_cache['access_token'])
6768

6869
@unittest.skipUnless("client_secret" in CONFIG, "Missing client secret")
69-
def test_confidential_client_using_secret(self):
70+
def test_client_secret(self):
7071
app = ConfidentialClientApplication(
7172
CONFIG["client_id"], client_credential=CONFIG.get("client_secret"),
7273
authority=CONFIG.get("authority"))
@@ -76,7 +77,7 @@ def test_confidential_client_using_secret(self):
7677
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
7778

7879
@unittest.skipUnless("client_certificate" in CONFIG, "Missing client cert")
79-
def test_confidential_client_using_certificate(self):
80+
def test_client_certificate(self):
8081
client_certificate = CONFIG["client_certificate"]
8182
assert ("private_key_path" in client_certificate
8283
and "thumbprint" in client_certificate)
@@ -99,8 +100,8 @@ class TestPublicClientApplication(Oauth2TestCase):
99100
def test_username_password(self):
100101
self.app = PublicClientApplication(
101102
CONFIG["client_id"], authority=CONFIG["authority"])
102-
result = self.app.acquire_token_with_username_password(
103-
CONFIG["username"], CONFIG["password"], scope=CONFIG.get("scope"))
103+
result = self.app.acquire_token_by_username_password(
104+
CONFIG["username"], CONFIG["password"], scopes=CONFIG.get("scope"))
104105
self.assertLoosely(result)
105106
self.assertCacheWorks(result)
106107

@@ -124,7 +125,7 @@ def test_auth_code(self):
124125
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
125126
self.assertNotEqual(ac, None)
126127

127-
result = self.app.acquire_token_with_authorization_code(
128+
result = self.app.acquire_token_by_authorization_code(
128129
ac, CONFIG["scope"], redirect_uri=redirect_uri)
129130
logging.debug("cache = %s", json.dumps(self.app.token_cache._cache, indent=4))
130131
self.assertIn(
@@ -136,14 +137,13 @@ def test_auth_code(self):
136137
self.assertCacheWorks(result)
137138

138139
def test_device_flow(self):
139-
flow = self.app.initiate_device_flow(scope=CONFIG.get("scope"))
140+
flow = self.app.initiate_device_flow(scopes=CONFIG.get("scope"))
140141
logging.warn(flow["message"])
141142

142143
duration = 30
143144
logging.warn("We will wait up to %d seconds for you to sign in" % duration)
144-
result = self.app.acquire_token_by_device_flow(
145-
flow,
146-
exit_condition=lambda end=time.time() + duration: time.time() > end)
145+
flow["expires_at"] = time.time() + duration # Shorten the time for quick test
146+
result = self.app.acquire_token_by_device_flow(flow)
147147
self.assertLoosely(
148148
result,
149149
assertion=lambda: self.assertIn('access_token', result),

0 commit comments

Comments
 (0)