99DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
1010
1111
12- def _hash (raw ):
13- return sha256 (repr (raw ).encode ("utf-8" )).hexdigest ()
14-
15-
16- def _parse_http_429_5xx_retry_after (result = None , ** ignored ):
17- """Return seconds to throttle"""
18- assert result is not None , """
19- The signature defines it with a default value None,
20- only because the its shape is already decided by the
21- IndividualCache's.__call__().
22- In actual code path, the result parameter here won't be None.
23- """
24- response = result
25- lowercase_headers = {k .lower (): v for k , v in getattr (
26- # Historically, MSAL's HttpResponse does not always have headers
27- response , "headers" , {}).items ()}
28- if not (response .status_code == 429 or response .status_code >= 500
29- or "retry-after" in lowercase_headers ):
30- return 0 # Quick exit
31- default = 60 # Recommended at the end of
32- # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
33- retry_after = lowercase_headers .get ("retry-after" , default )
34- try :
35- # AAD's retry_after uses integer format only
36- # https://stackoverflow.microsoft.com/questions/264931/264932
37- delay_seconds = int (retry_after )
38- except ValueError :
39- delay_seconds = default
40- return min (3600 , delay_seconds )
12+ class RetryAfterParser (object ):
13+ def __init__ (self , default_value = None ):
14+ self ._default_value = 5 if default_value is None else default_value
15+
16+ def parse (self , * , result , ** ignored ):
17+ """Return seconds to throttle"""
18+ response = result
19+ lowercase_headers = {k .lower (): v for k , v in getattr (
20+ # Historically, MSAL's HttpResponse does not always have headers
21+ response , "headers" , {}).items ()}
22+ if not (response .status_code == 429 or response .status_code >= 500
23+ or "retry-after" in lowercase_headers ):
24+ return 0 # Quick exit
25+ retry_after = lowercase_headers .get ("retry-after" , self ._default_value )
26+ try :
27+ # AAD's retry_after uses integer format only
28+ # https://stackoverflow.microsoft.com/questions/264931/264932
29+ delay_seconds = int (retry_after )
30+ except ValueError :
31+ delay_seconds = self ._default_value
32+ return min (3600 , delay_seconds )
4133
4234
4335def _extract_data (kwargs , key , default = None ):
@@ -53,7 +45,7 @@ class ThrottledHttpClientBase(object):
5345
5446 The subclass should implement post() and/or get()
5547 """
56- def __init__ (self , http_client , http_cache ):
48+ def __init__ (self , http_client , * , http_cache = None ):
5749 self .http_client = http_client
5850 self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
5951 mapping = http_cache if http_cache is not None else {},
@@ -70,10 +62,14 @@ def get(self, *args, **kwargs):
7062 def close (self ):
7163 return self .http_client .close ()
7264
65+ @staticmethod
66+ def _hash (raw ):
67+ return sha256 (repr (raw ).encode ("utf-8" )).hexdigest ()
68+
7369
7470class ThrottledHttpClient (ThrottledHttpClientBase ):
75- def __init__ (self , http_client , http_cache ):
76- super (ThrottledHttpClient , self ).__init__ (http_client , http_cache )
71+ def __init__ (self , http_client , * , default_throttle_time = None , ** kwargs ):
72+ super (ThrottledHttpClient , self ).__init__ (http_client , ** kwargs )
7773
7874 _post = http_client .post # We'll patch _post, and keep original post() intact
7975
@@ -86,22 +82,22 @@ def __init__(self, http_client, http_cache):
8682 args [0 ], # It is the url, typically containing authority and tenant
8783 _extract_data (kwargs , "client_id" ), # Per internal specs
8884 _extract_data (kwargs , "scope" ), # Per internal specs
89- _hash (
85+ self . _hash (
9086 # The followings are all approximations of the "account" concept
9187 # to support per-account throttling.
9288 # TODO: We may want to disable it for confidential client, though
9389 _extract_data (kwargs , "refresh_token" , # "account" during refresh
9490 _extract_data (kwargs , "code" , # "account" of auth code grant
9591 _extract_data (kwargs , "username" )))), # "account" of ROPC
9692 ),
97- expires_in = _parse_http_429_5xx_retry_after ,
93+ expires_in = RetryAfterParser ( default_throttle_time or 5 ). parse ,
9894 )(_post )
9995
10096 _post = IndividualCache ( # It covers the "UI required cache"
10197 mapping = self ._expiring_mapping ,
10298 key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
10399 args [0 ], # It is the url, typically containing authority and tenant
104- _hash (
100+ self . _hash (
105101 # Here we use literally all parameters, even those short-lived
106102 # parameters containing timestamps (WS-Trust or POP assertion),
107103 # because they will automatically be cleaned up by ExpiringMapping.
@@ -140,7 +136,7 @@ def __init__(self, http_client, http_cache):
140136 mapping = self ._expiring_mapping ,
141137 key_maker = lambda func , args , kwargs : "GET {} hash={} 2xx" .format (
142138 args [0 ], # It is the url, sometimes containing inline params
143- _hash (kwargs .get ("params" , "" )),
139+ self . _hash (kwargs .get ("params" , "" )),
144140 ),
145141 expires_in = lambda result = None , ** ignored :
146142 3600 * 24 if 200 <= result .status_code < 300 else 0 ,
0 commit comments