33
44from .individual_cache import _IndividualCache as IndividualCache
55from .individual_cache import _ExpiringMapping as ExpiringMapping
6+ from .oauth2cli .http import Response
7+ from .exceptions import MsalServiceError
68
79
810# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
@@ -67,13 +69,30 @@ def _hash(raw):
6769 return sha256 (repr (raw ).encode ("utf-8" )).hexdigest ()
6870
6971
72+ class NormalizedResponse (Response ):
73+ """A http response with the shape defined in Response,
74+ but contains only the data we will store in cache.
75+ """
76+ def __init__ (self , raw_response ):
77+ super ().__init__ ()
78+ self .status_code = raw_response .status_code
79+ self .text = raw_response .text
80+ self .headers = raw_response .headers
81+
82+ ## Note: Don't use the following line,
83+ ## because when being pickled, it will indirectly pickle the whole raw_response
84+ # self.raise_for_status = raw_response.raise_for_status
85+ def raise_for_status (self ):
86+ if self .status_code >= 400 :
87+ raise MsalServiceError ("HTTP Error: {}" .format (self .status_code ))
88+
89+
7090class ThrottledHttpClient (ThrottledHttpClientBase ):
91+ """A throttled http client wrapper that is tailored for MSAL."""
7192 def __init__ (self , http_client , * , default_throttle_time = None , ** kwargs ):
93+ """Decorate self.post() and self.get() dynamically"""
7294 super (ThrottledHttpClient , self ).__init__ (http_client , ** kwargs )
73-
74- _post = http_client .post # We'll patch _post, and keep original post() intact
75-
76- _post = IndividualCache (
95+ self .post = IndividualCache (
7796 # Internal specs requires throttling on at least token endpoint,
7897 # here we have a generic patch for POST on all endpoints.
7998 mapping = self ._expiring_mapping ,
@@ -91,9 +110,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91110 _extract_data (kwargs , "username" )))), # "account" of ROPC
92111 ),
93112 expires_in = RetryAfterParser (default_throttle_time or 5 ).parse ,
94- )(_post )
113+ )(self . post )
95114
96- _post = IndividualCache ( # It covers the "UI required cache"
115+ self . post = IndividualCache ( # It covers the "UI required cache"
97116 mapping = self ._expiring_mapping ,
98117 key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
99118 args [0 ], # It is the url, typically containing authority and tenant
@@ -128,9 +147,7 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
128147 and "retry-after" not in set ( # Leave it to the Retry-After decorator
129148 h .lower () for h in getattr (result , "headers" , {}).keys ())
130149 else 0 ,
131- )(_post )
132-
133- self .post = _post
150+ )(self .post )
134151
135152 self .get = IndividualCache ( # Typically those discovery GETs
136153 mapping = self ._expiring_mapping ,
@@ -140,9 +157,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140157 ),
141158 expires_in = lambda result = None , ** ignored :
142159 3600 * 24 if 200 <= result .status_code < 300 else 0 ,
143- )(http_client .get )
160+ )(self .get )
144161
145- # The following 2 methods have been defined dynamically by __init__()
146- #def post(self, *args, **kwargs): pass
147- #def get(self, *args, **kwargs): pass
162+ def post (self , * args , ** kwargs ):
163+ return NormalizedResponse (super (ThrottledHttpClient , self ).post (* args , ** kwargs ))
148164
165+ def get (self , * args , ** kwargs ):
166+ return NormalizedResponse (super (ThrottledHttpClient , self ).get (* args , ** kwargs ))
0 commit comments