Skip to content

Commit f2ce2ee

Browse files
committed
obtain_token_by_refresh_token() accepts a string RT for migration purpose
1 parent a1cd299 commit f2ce2ee

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

oauth2cli/oauth2.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -380,26 +380,34 @@ def __init__(self,
380380
self.on_removing_rt = on_removing_rt
381381
self.on_updating_rt = on_updating_rt
382382

383-
def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
383+
def _obtain_token(self, grant_type, params=None, data=None,
384+
rt_getter=lambda token_item: token_item["refresh_token"],
385+
*args, **kwargs):
386+
RT = "refresh_token"
387+
_data = data.copy() # to prevent side effect
388+
refresh_token = _data.get(RT)
389+
if grant_type == RT and isinstance(refresh_token, dict):
390+
_data[RT] = rt_getter(refresh_token) # Put raw RT in _data
384391
resp = super(Client, self)._obtain_token(
385-
grant_type, params, data, *args, **kwargs)
392+
grant_type, params, _data, *args, **kwargs)
386393
if "error" not in resp:
387394
_resp = resp.copy()
388-
if grant_type == "refresh_token" and "refresh_token" in _resp:
389-
_resp.pop("refresh_token") # We'll handle this in its own method
395+
if grant_type == RT and RT in _resp and isinstance(refresh_token, dict):
396+
_resp.pop(RT) # So we skip it in on_obtaining_tokens(); it will
397+
# be handled in self.obtain_token_by_refresh_token()
390398
if "scope" in _resp:
391399
scope = _resp["scope"].split() # It is conceptually a set,
392400
# but we represent it as a list which can be persisted to JSON
393401
else:
394402
# TODO: Deal with absent scope in authorization grant
395-
scope = data.get("scope")
403+
scope = _data.get("scope")
396404
self.on_obtaining_tokens({
397405
"client_id": self.client_id,
398406
"scope": scope,
399407
"token_endpoint": self.configuration["token_endpoint"],
400408
"grant_type": grant_type, # can be used to know an IdToken-less
401409
# response is for an app or for a user
402-
"response": _resp, "params": params, "data": data,
410+
"response": _resp, "params": params, "data": _data,
403411
})
404412
return resp
405413

@@ -411,26 +419,29 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
411419
"""This is an "overload" which accepts a refresh token item as a dict,
412420
therefore this method can relay refresh_token item to event listeners.
413421
414-
:param token_item: A refresh token item came from storage
422+
:param token_item:
423+
A refresh token item as a dict, came from the cache managed by this lib.
424+
425+
Alternatively, you can still use a refresh token (RT) as a string,
426+
supposedly came from a token cache managed by a different library,
427+
then this library will store the new RT (if Authority Server issued one)
428+
into this lib's cache. This is a way to migrate from other lib to us.
415429
:param scope: If omitted, is treated as equal to the scope originally
416430
granted by the resource ownser,
417431
according to https://tools.ietf.org/html/rfc6749#section-6
418432
:param rt_getter: A callable used to extract the RT from token_item
419433
:param on_removing_rt: If absent, fall back to the one defined in initialization
420434
"""
421-
if isinstance(token_item, str):
422-
# Satisfy the L of SOLID, although we expect caller uses a dict
423-
return super(Client, self).obtain_token_by_refresh_token(
424-
token_item, scope=scope, **kwargs)
435+
resp = super(Client, self).obtain_token_by_refresh_token(
436+
token_item, scope=scope,
437+
rt_getter=rt_getter, # Wire up this for _obtain_token()
438+
**kwargs)
425439
if isinstance(token_item, dict):
426-
resp = super(Client, self).obtain_token_by_refresh_token(
427-
rt_getter(token_item), scope=scope, **kwargs)
428440
if resp.get('error') == 'invalid_grant':
429441
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
430442
if 'refresh_token' in resp:
431443
self.on_updating_rt(token_item, resp['refresh_token'])
432-
return resp
433-
raise ValueError("token_item should not be a type %s" % type(token_item))
444+
return resp
434445

435446
def obtain_token_by_assertion(
436447
self, assertion, grant_type, scope=None, **kwargs):

0 commit comments

Comments
 (0)