1+ import sys
12import base64
23import functools
34import hashlib
@@ -205,17 +206,31 @@ class _TokenState(Enum):
205206class Refreshable (TokenSource ):
206207 """A token source that supports refreshing expired tokens."""
207208
208- _EXECUTOR = ThreadPoolExecutor (max_workers = 10 )
209+ _EXECUTOR = None
210+ _EXECUTOR_LOCK = threading .Lock ()
209211 _DEFAULT_STALE_DURATION = timedelta (minutes = 3 )
210212
213+ @classmethod
214+ def _get_executor (cls ):
215+ """Lazy initialization of the ThreadPoolExecutor."""
216+ if cls ._EXECUTOR is None :
217+ with cls ._EXECUTOR_LOCK :
218+ if cls ._EXECUTOR is None :
219+ # This thread pool has multiple workers because it is shared by all instances of Refreshable.
220+ cls ._EXECUTOR = ThreadPoolExecutor (max_workers = 10 )
221+ return cls ._EXECUTOR
222+
211223 def __init__ (self ,
212224 token : Token = None ,
213225 disable_async : bool = True ,
214226 stale_duration : timedelta = _DEFAULT_STALE_DURATION ):
215- self ._lock = threading .Lock ()
216- self ._token = token
227+ # Config properties
217228 self ._stale_duration = stale_duration
218229 self ._disable_async = disable_async
230+ # Lock
231+ self ._lock = threading .Lock ()
232+ # Non Thread safe properties. Protected by the lock above.
233+ self ._token = token
219234 self ._is_refreshing = False
220235 self ._refresh_err = False
221236
@@ -231,8 +246,9 @@ def _async_token(self) -> Token:
231246 If the token is stale, triggers an asynchronous refresh.
232247 If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233248 """
234- with self ._lock :
235- token_state = self ._token_state ()
249+ token_state = self ._token_state ()
250+
251+ with self ._lock :
236252 token = self ._token
237253
238254 if token_state == _TokenState .FRESH :
@@ -244,20 +260,22 @@ def _async_token(self) -> Token:
244260
245261 def _token_state (self ) -> _TokenState :
246262 """Returns the current state of the token."""
247- if not self ._token or not self ._token .valid :
248- return _TokenState .EXPIRED
249- if not self ._token .expiry :
263+ with self ._lock :
264+ if not self ._token or not self ._token .valid :
265+ return _TokenState .EXPIRED
266+ if not self ._token .expiry :
267+ return _TokenState .FRESH
268+
269+ lifespan = self ._token .expiry - datetime .now ()
270+ if lifespan < timedelta (seconds = 0 ):
271+ return _TokenState .EXPIRED
272+ if lifespan < self ._stale_duration :
273+ return _TokenState .STALE
250274 return _TokenState .FRESH
251275
252- lifespan = self ._token .expiry - datetime .now ()
253- if lifespan < timedelta (seconds = 0 ):
254- return _TokenState .EXPIRED
255- if lifespan < self ._stale_duration :
256- return _TokenState .STALE
257- return _TokenState .FRESH
258-
259276 def _blocking_token (self ) -> Token :
260277 """Returns a token, blocking if necessary to refresh it."""
278+ state = self ._token_state ()
261279 with self ._lock :
262280 # This is important to recover from potential previous failed attempts
263281 # to refresh the token asynchronously.
@@ -267,7 +285,7 @@ def _blocking_token(self) -> Token:
267285 # It's possible that the token got refreshed (either by a _blocking_refresh or
268286 # an _async_refresh call) while this particular call was waiting to acquire
269287 # the lock. This check avoids refreshing the token again in such cases.
270- if self . _token_state () != _TokenState .EXPIRED :
288+ if state != _TokenState .EXPIRED :
271289 return self ._token
272290
273291 self ._token = self .refresh ()
@@ -276,20 +294,29 @@ def _blocking_token(self) -> Token:
276294 def _trigger_async_refresh (self ):
277295 """Starts an asynchronous refresh if none is in progress."""
278296
279- # Note: _refresh_internal function is not thread safe.
280- # Only call it inside the lock.
281297 def _refresh_internal ():
298+ new_token : Token = None
282299 try :
283- self ._token = self .refresh ()
284- except Exception :
285- self ._refresh_err = True
286- finally :
300+ new_token = self .refresh ()
301+ except Exception as e :
302+ # This happens on a thread, so we don't want to propagate the error.
303+ # Instead, if there is no new_token for any reason, we will disable async refresh below
304+ # But we will do it inside the lock.
305+ logger .warning (f'Tried to refresh token asynchronously, but failed: { e } ' )
306+ pass
307+
308+ with self ._lock :
309+ if new_token is not None :
310+ self ._token = new_token
311+ else :
312+ self ._refresh_err = True
287313 self ._is_refreshing = False
314+
288315
289316 with self ._lock :
290317 if not self ._is_refreshing and not self ._refresh_err :
291318 self ._is_refreshing = True
292- self . _EXECUTOR .submit (_refresh_internal )
319+ Refreshable . _get_executor () .submit (_refresh_internal )
293320
294321 @abstractmethod
295322 def refresh (self ) -> Token :
0 commit comments