@@ -454,17 +454,20 @@ def __init__(self,
454454 self .on_removing_rt = on_removing_rt
455455 self .on_updating_rt = on_updating_rt
456456
457- def _obtain_token (self , grant_type , params = None , data = None , * args , ** kwargs ):
458- RT = "refresh_token"
457+ def _obtain_token (
458+ self , grant_type , params = None , data = None ,
459+ also_save_rt = False ,
460+ * args , ** kwargs ):
459461 _data = data .copy () # to prevent side effect
460- refresh_token = _data .get (RT )
461462 resp = super (Client , self )._obtain_token (
462463 grant_type , params , _data , * args , ** kwargs )
463464 if "error" not in resp :
464465 _resp = resp .copy ()
465- if grant_type == RT and RT in _resp and isinstance (refresh_token , dict ):
466- _resp .pop (RT ) # So we skip it in on_obtaining_tokens(); it will
467- # be handled in self.obtain_token_by_refresh_token()
466+ RT = "refresh_token"
467+ if grant_type == RT and RT in _resp and not also_save_rt :
468+ # Then we skip it from on_obtaining_tokens();
469+ # Leave it to self.obtain_token_by_refresh_token()
470+ _resp .pop (RT , None )
468471 if "scope" in _resp :
469472 scope = _resp ["scope" ].split () # It is conceptually a set,
470473 # but we represent it as a list which can be persisted to JSON
@@ -486,6 +489,7 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
486489 def obtain_token_by_refresh_token (self , token_item , scope = None ,
487490 rt_getter = lambda token_item : token_item ["refresh_token" ],
488491 on_removing_rt = None ,
492+ on_updating_rt = None ,
489493 ** kwargs ):
490494 # type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
491495 """This is an overload which will trigger token storage callbacks.
@@ -503,16 +507,28 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
503507 according to https://tools.ietf.org/html/rfc6749#section-6
504508 :param rt_getter: A callable to translate the token_item to a raw RT string
505509 :param on_removing_rt: If absent, fall back to the one defined in initialization
510+
511+ :param on_updating_rt:
512+ Default to None, it will fall back to the one defined in initialization.
513+ This is the most common case.
514+
515+ As a special case, you can pass in a False,
516+ then this function will NOT trigger on_updating_rt() for RT UPDATE,
517+ instead it will allow the RT to be added by on_obtaining_tokens().
518+ This behavior is useful when you are migrating RTs from elsewhere
519+ into a token storage managed by this library.
506520 """
507521 resp = super (Client , self ).obtain_token_by_refresh_token (
508522 rt_getter (token_item )
509523 if not isinstance (token_item , string_types ) else token_item ,
510524 scope = scope ,
525+ also_save_rt = on_updating_rt is False ,
511526 ** kwargs )
512527 if resp .get ('error' ) == 'invalid_grant' :
513528 (on_removing_rt or self .on_removing_rt )(token_item ) # Discard old RT
514- if 'refresh_token' in resp :
515- self .on_updating_rt (token_item , resp ['refresh_token' ])
529+ RT = "refresh_token"
530+ if on_updating_rt is not False and RT in resp :
531+ (on_updating_rt or self .on_updating_rt )(token_item , resp [RT ])
516532 return resp
517533
518534 def obtain_token_by_assertion (
0 commit comments