Skip to content

Commit 19938ef

Browse files
committed
Merge remote branch
2 parents 53949a9 + dac051a commit 19938ef

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

msal/oauth2cli/oauth2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
108108
data=None, # All relevant data, which will go into the http body
109109
headers=None, # a dict to be sent as request headers
110110
timeout=None,
111+
post=None, # A callable to replace requests.post(), for testing.
112+
# Such as: lambda url, **kwargs:
113+
# Mock(status_code=200, json=Mock(return_value={}))
111114
**kwargs # Relay all extra parameters to underlying requests
112115
): # Returns the json object came from the OAUTH2 response
113116
_data = {'client_id': self.client_id, 'grant_type': grant_type}
@@ -133,7 +136,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
133136
raise ValueError("token_endpoint not found in configuration")
134137
_headers = {'Accept': 'application/json'}
135138
_headers.update(headers or {})
136-
resp = self.session.post(
139+
resp = (post or self.session.post)(
137140
self.configuration["token_endpoint"],
138141
headers=_headers, params=params, data=_data, auth=auth,
139142
timeout=timeout or self.timeout,
@@ -393,16 +396,18 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
393396

394397
def obtain_token_by_refresh_token(self, token_item, scope=None,
395398
rt_getter=lambda token_item: token_item["refresh_token"],
399+
on_removing_rt=None,
396400
**kwargs):
397401
# type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
398402
"""This is an "overload" which accepts a refresh token item as a dict,
399403
therefore this method can relay refresh_token item to event listeners.
400404
401-
:param refresh_token_item: A refresh token item came from storage
405+
:param token_item: A refresh token item came from storage
402406
:param scope: If omitted, is treated as equal to the scope originally
403407
granted by the resource ownser,
404408
according to https://tools.ietf.org/html/rfc6749#section-6
405409
:param rt_getter: A callable used to extract the RT from token_item
410+
:param on_removing_rt: If absent, fall back to the one defined in initialization
406411
"""
407412
if isinstance(token_item, str):
408413
# Satisfy the L of SOLID, although we expect caller uses a dict
@@ -412,7 +417,7 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
412417
resp = super(Client, self).obtain_token_by_refresh_token(
413418
rt_getter(token_item), scope=scope, **kwargs)
414419
if resp.get('error') == 'invalid_grant':
415-
self.on_removing_rt(token_item) # Discard old RT
420+
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
416421
if 'refresh_token' in resp:
417422
self.on_updating_rt(token_item, resp['refresh_token'])
418423
return resp

0 commit comments

Comments
 (0)