@@ -45,25 +45,42 @@ def _extract_data(kwargs, key, default=None):
4545 return data .get (key ) if isinstance (data , dict ) else default
4646
4747
48- class ThrottledHttpClient (object ):
49- def __init__ (self , http_client , http_cache ):
50- """Throttle the given http_client by storing and retrieving data from cache.
48+ class ThrottledHttpClientBase (object ):
49+ """Throttle the given http_client by storing and retrieving data from cache.
5150
52- This wrapper exists so that our patching post() and get() would prevent
53- re-patching side effect when/if same http_client being reused.
54- """
55- expiring_mapping = ExpiringMapping ( # It will automatically clean up
51+ This wrapper exists so that our patching post() and get() would prevent
52+ re-patching side effect when/if same http_client being reused.
53+
54+ The subclass should implement post() and/or get()
55+ """
56+ def __init__ (self , http_client , http_cache ):
57+ self .http_client = http_client
58+ self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
5659 mapping = http_cache if http_cache is not None else {},
5760 capacity = 1024 , # To prevent cache blowing up especially for CCA
5861 lock = Lock (), # TODO: This should ideally also allow customization
5962 )
6063
64+ def post (self , * args , ** kwargs ):
65+ return self .http_client .post (* args , ** kwargs )
66+
67+ def get (self , * args , ** kwargs ):
68+ return self .http_client .get (* args , ** kwargs )
69+
70+ def close (self ):
71+ return self .http_client .close ()
72+
73+
74+ class ThrottledHttpClient (ThrottledHttpClientBase ):
75+ def __init__ (self , http_client , http_cache ):
76+ super (ThrottledHttpClient , self ).__init__ (http_client , http_cache )
77+
6178 _post = http_client .post # We'll patch _post, and keep original post() intact
6279
6380 _post = IndividualCache (
6481 # Internal specs requires throttling on at least token endpoint,
6582 # here we have a generic patch for POST on all endpoints.
66- mapping = expiring_mapping ,
83+ mapping = self . _expiring_mapping ,
6784 key_maker = lambda func , args , kwargs :
6885 "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After" .format (
6986 args [0 ], # It is the url, typically containing authority and tenant
@@ -81,7 +98,7 @@ def __init__(self, http_client, http_cache):
8198 )(_post )
8299
83100 _post = IndividualCache ( # It covers the "UI required cache"
84- mapping = expiring_mapping ,
101+ mapping = self . _expiring_mapping ,
85102 key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
86103 args [0 ], # It is the url, typically containing authority and tenant
87104 _hash (
@@ -120,7 +137,7 @@ def __init__(self, http_client, http_cache):
120137 self .post = _post
121138
122139 self .get = IndividualCache ( # Typically those discovery GETs
123- mapping = expiring_mapping ,
140+ mapping = self . _expiring_mapping ,
124141 key_maker = lambda func , args , kwargs : "GET {} hash={} 2xx" .format (
125142 args [0 ], # It is the url, sometimes containing inline params
126143 _hash (kwargs .get ("params" , "" )),
@@ -129,13 +146,7 @@ def __init__(self, http_client, http_cache):
129146 3600 * 24 if 200 <= result .status_code < 300 else 0 ,
130147 )(http_client .get )
131148
132- self ._http_client = http_client
133-
134149 # The following 2 methods have been defined dynamically by __init__()
135150 #def post(self, *args, **kwargs): pass
136151 #def get(self, *args, **kwargs): pass
137152
138- def close (self ):
139- """MSAL won't need this. But we allow throttled_http_client.close() anyway"""
140- return self ._http_client .close ()
141-
0 commit comments