33
44from .individual_cache import _IndividualCache as IndividualCache
55from .individual_cache import _ExpiringMapping as ExpiringMapping
6+ from .oauth2cli .http import Response
7+ from .exceptions import MsalServiceError
68
79
810# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
911DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
1012
1113
14+ def _get_headers (response ):
15+ # MSAL's HttpResponse did not have headers until 1.23.0
16+ # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
17+ # This helper ensures graceful degradation to {} without exception
18+ return getattr (response , "headers" , {})
19+
20+
1221class RetryAfterParser (object ):
22+ FIELD_NAME_LOWER = "Retry-After" .lower ()
1323 def __init__ (self , default_value = None ):
1424 self ._default_value = 5 if default_value is None else default_value
1525
1626 def parse (self , * , result , ** ignored ):
1727 """Return seconds to throttle"""
1828 response = result
19- lowercase_headers = {k .lower (): v for k , v in getattr (
20- # Historically, MSAL's HttpResponse does not always have headers
21- response , "headers" , {}).items ()}
29+ lowercase_headers = {k .lower (): v for k , v in _get_headers (response ).items ()}
2230 if not (response .status_code == 429 or response .status_code >= 500
23- or "retry-after" in lowercase_headers ):
31+ or self . FIELD_NAME_LOWER in lowercase_headers ):
2432 return 0 # Quick exit
25- retry_after = lowercase_headers .get ("retry-after" , self ._default_value )
33+ retry_after = lowercase_headers .get (self . FIELD_NAME_LOWER , self ._default_value )
2634 try :
2735 # AAD's retry_after uses integer format only
2836 # https://stackoverflow.microsoft.com/questions/264931/264932
@@ -37,27 +45,55 @@ def _extract_data(kwargs, key, default=None):
3745 return data .get (key ) if isinstance (data , dict ) else default
3846
3947
48+ class NormalizedResponse (Response ):
49+ """A http response with the shape defined in Response,
50+ but contains only the data we will store in cache.
51+ """
52+ def __init__ (self , raw_response ):
53+ super ().__init__ ()
54+ self .status_code = raw_response .status_code
55+ self .text = raw_response .text
56+ self .headers = {
57+ k .lower (): v for k , v in _get_headers (raw_response ).items ()
58+ # Attempted storing only a small set of headers (such as Retry-After),
59+ # but it tends to lead to missing information (such as WWW-Authenticate).
60+ # So we store all headers, which are expected to contain only public info,
61+ # because we throttle only error responses and public responses.
62+ }
63+
64+ ## Note: Don't use the following line,
65+ ## because when being pickled, it will indirectly pickle the whole raw_response
66+ # self.raise_for_status = raw_response.raise_for_status
67+ def raise_for_status (self ):
68+ if self .status_code >= 400 :
69+ raise MsalServiceError ("HTTP Error: {}" .format (self .status_code ))
70+
71+
4072class ThrottledHttpClientBase (object ):
4173 """Throttle the given http_client by storing and retrieving data from cache.
4274
43- This wrapper exists so that our patching post() and get() would prevent
44- re-patching side effect when/if same http_client being reused.
75+ This base exists so that:
76+ 1. These base post() and get() will return a NormalizedResponse
77+ 2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient.
4578
46- The subclass should implement post() and/or get()
79+ Subclasses shall only need to dynamically decorate their post() and get() methods
80+ in their __init__() method.
4781 """
4882 def __init__ (self , http_client , * , http_cache = None ):
49- self .http_client = http_client
83+ self .http_client = http_client .http_client if isinstance (
84+ # If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client
85+ http_client , ThrottledHttpClientBase ) else http_client
5086 self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
5187 mapping = http_cache if http_cache is not None else {},
5288 capacity = 1024 , # To prevent cache blowing up especially for CCA
5389 lock = Lock (), # TODO: This should ideally also allow customization
5490 )
5591
5692 def post (self , * args , ** kwargs ):
57- return self .http_client .post (* args , ** kwargs )
93+ return NormalizedResponse ( self .http_client .post (* args , ** kwargs ) )
5894
5995 def get (self , * args , ** kwargs ):
60- return self .http_client .get (* args , ** kwargs )
96+ return NormalizedResponse ( self .http_client .get (* args , ** kwargs ) )
6197
6298 def close (self ):
6399 return self .http_client .close ()
@@ -68,12 +104,11 @@ def _hash(raw):
68104
69105
70106class ThrottledHttpClient (ThrottledHttpClientBase ):
71- def __init__ (self , http_client , * , default_throttle_time = None , ** kwargs ):
72- super (ThrottledHttpClient , self ).__init__ (http_client , ** kwargs )
73-
74- _post = http_client .post # We'll patch _post, and keep original post() intact
75-
76- _post = IndividualCache (
107+ """A throttled http client that is used by MSAL's non-managed identity clients."""
108+ def __init__ (self , * args , default_throttle_time = None , ** kwargs ):
109+ """Decorate self.post() and self.get() dynamically"""
110+ super (ThrottledHttpClient , self ).__init__ (* args , ** kwargs )
111+ self .post = IndividualCache (
77112 # Internal specs requires throttling on at least token endpoint,
78113 # here we have a generic patch for POST on all endpoints.
79114 mapping = self ._expiring_mapping ,
@@ -91,9 +126,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91126 _extract_data (kwargs , "username" )))), # "account" of ROPC
92127 ),
93128 expires_in = RetryAfterParser (default_throttle_time or 5 ).parse ,
94- )(_post )
129+ )(self . post )
95130
96- _post = IndividualCache ( # It covers the "UI required cache"
131+ self . post = IndividualCache ( # It covers the "UI required cache"
97132 mapping = self ._expiring_mapping ,
98133 key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
99134 args [0 ], # It is the url, typically containing authority and tenant
@@ -125,12 +160,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
125160 isinstance (kwargs .get ("data" ), dict )
126161 and kwargs ["data" ].get ("grant_type" ) == DEVICE_AUTH_GRANT
127162 )
128- and "retry-after" not in set ( # Leave it to the Retry-After decorator
129- h .lower () for h in getattr (result , "headers" , {}). keys ( ))
163+ and RetryAfterParser . FIELD_NAME_LOWER not in set ( # Otherwise leave it to the Retry-After decorator
164+ h .lower () for h in _get_headers (result ))
130165 else 0 ,
131- )(_post )
132-
133- self .post = _post
166+ )(self .post )
134167
135168 self .get = IndividualCache ( # Typically those discovery GETs
136169 mapping = self ._expiring_mapping ,
@@ -140,9 +173,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140173 ),
141174 expires_in = lambda result = None , ** ignored :
142175 3600 * 24 if 200 <= result .status_code < 300 else 0 ,
143- )(http_client .get )
144-
145- # The following 2 methods have been defined dynamically by __init__()
146- #def post(self, *args, **kwargs): pass
147- #def get(self, *args, **kwargs): pass
148-
176+ )(self .get )
0 commit comments