33
44from .individual_cache import _IndividualCache as IndividualCache
55from .individual_cache import _ExpiringMapping as ExpiringMapping
6+ from .oauth2cli .http import Response
7+ from .exceptions import MsalServiceError
68
79
810# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
911DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
1012
1113
1214class RetryAfterParser (object ):
15+ FIELD_NAME_LOWER = "Retry-After" .lower ()
1316 def __init__ (self , default_value = None ):
1417 self ._default_value = 5 if default_value is None else default_value
1518
@@ -20,9 +23,9 @@ def parse(self, *, result, **ignored):
2023 # Historically, MSAL's HttpResponse does not always have headers
2124 response , "headers" , {}).items ()}
2225 if not (response .status_code == 429 or response .status_code >= 500
23- or "retry-after" in lowercase_headers ):
26+ or self . FIELD_NAME_LOWER in lowercase_headers ):
2427 return 0 # Quick exit
25- retry_after = lowercase_headers .get ("retry-after" , self ._default_value )
28+ retry_after = lowercase_headers .get (self . FIELD_NAME_LOWER , self ._default_value )
2629 try :
2730 # AAD's retry_after uses integer format only
2831 # https://stackoverflow.microsoft.com/questions/264931/264932
@@ -37,27 +40,52 @@ def _extract_data(kwargs, key, default=None):
3740 return data .get (key ) if isinstance (data , dict ) else default
3841
3942
43+ class NormalizedResponse (Response ):
44+ """A http response with the shape defined in Response,
45+ but contains only the data we will store in cache.
46+ """
47+ def __init__ (self , raw_response ):
48+ super ().__init__ ()
49+ self .status_code = raw_response .status_code
50+ self .text = raw_response .text
51+ self .headers = { # Only keep the headers which ThrottledHttpClient cares about
52+ k : v for k , v in raw_response .headers .items ()
53+ if k .lower () == RetryAfterParser .FIELD_NAME_LOWER
54+ }
55+
56+ ## Note: Don't use the following line,
57+ ## because when being pickled, it will indirectly pickle the whole raw_response
58+ # self.raise_for_status = raw_response.raise_for_status
59+ def raise_for_status (self ):
60+ if self .status_code >= 400 :
61+ raise MsalServiceError ("HTTP Error: {}" .format (self .status_code ))
62+
63+
4064class ThrottledHttpClientBase (object ):
4165 """Throttle the given http_client by storing and retrieving data from cache.
4266
43- This wrapper exists so that our patching post() and get() would prevent
44- re-patching side effect when/if same http_client being reused.
67+ This base exists so that:
68+ 1. These base post() and get() will return a NormalizedResponse
69+ 2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient.
4570
46- The subclass should implement post() and/or get()
71+ Subclasses shall only need to dynamically decorate their post() and get() methods
72+ in their __init__() method.
4773 """
4874 def __init__ (self , http_client , * , http_cache = None ):
49- self .http_client = http_client
75+ self .http_client = http_client .http_client if isinstance (
76+ # If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client
77+ http_client , ThrottledHttpClientBase ) else http_client
5078 self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
5179 mapping = http_cache if http_cache is not None else {},
5280 capacity = 1024 , # To prevent cache blowing up especially for CCA
5381 lock = Lock (), # TODO: This should ideally also allow customization
5482 )
5583
5684 def post (self , * args , ** kwargs ):
57- return self .http_client .post (* args , ** kwargs )
85+ return NormalizedResponse ( self .http_client .post (* args , ** kwargs ) )
5886
5987 def get (self , * args , ** kwargs ):
60- return self .http_client .get (* args , ** kwargs )
88+ return NormalizedResponse ( self .http_client .get (* args , ** kwargs ) )
6189
6290 def close (self ):
6391 return self .http_client .close ()
@@ -68,12 +96,11 @@ def _hash(raw):
6896
6997
7098class ThrottledHttpClient (ThrottledHttpClientBase ):
71- def __init__ (self , http_client , * , default_throttle_time = None , ** kwargs ):
72- super (ThrottledHttpClient , self ).__init__ (http_client , ** kwargs )
73-
74- _post = http_client .post # We'll patch _post, and keep original post() intact
75-
76- _post = IndividualCache (
99+ """A throttled http client that is used by MSAL's non-managed identity clients."""
100+ def __init__ (self , * args , default_throttle_time = None , ** kwargs ):
101+ """Decorate self.post() and self.get() dynamically"""
102+ super (ThrottledHttpClient , self ).__init__ (* args , ** kwargs )
103+ self .post = IndividualCache (
77104 # Internal specs requires throttling on at least token endpoint,
78105 # here we have a generic patch for POST on all endpoints.
79106 mapping = self ._expiring_mapping ,
@@ -91,9 +118,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91118 _extract_data (kwargs , "username" )))), # "account" of ROPC
92119 ),
93120 expires_in = RetryAfterParser (default_throttle_time or 5 ).parse ,
94- )(_post )
121+ )(self . post )
95122
96- _post = IndividualCache ( # It covers the "UI required cache"
123+ self . post = IndividualCache ( # It covers the "UI required cache"
97124 mapping = self ._expiring_mapping ,
98125 key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
99126 args [0 ], # It is the url, typically containing authority and tenant
@@ -125,12 +152,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
125152 isinstance (kwargs .get ("data" ), dict )
126153 and kwargs ["data" ].get ("grant_type" ) == DEVICE_AUTH_GRANT
127154 )
128- and "retry-after" not in set ( # Leave it to the Retry-After decorator
155+ and RetryAfterParser . FIELD_NAME_LOWER not in set ( # Otherwise leave it to the Retry-After decorator
129156 h .lower () for h in getattr (result , "headers" , {}).keys ())
130157 else 0 ,
131- )(_post )
132-
133- self .post = _post
158+ )(self .post )
134159
135160 self .get = IndividualCache ( # Typically those discovery GETs
136161 mapping = self ._expiring_mapping ,
@@ -140,9 +165,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140165 ),
141166 expires_in = lambda result = None , ** ignored :
142167 3600 * 24 if 200 <= result .status_code < 300 else 0 ,
143- )(http_client .get )
144-
145- # The following 2 methods have been defined dynamically by __init__()
146- #def post(self, *args, **kwargs): pass
147- #def get(self, *args, **kwargs): pass
148-
168+ )(self .get )
0 commit comments