@@ -205,17 +205,31 @@ class _TokenState(Enum):
205205class Refreshable (TokenSource ):
206206 """A token source that supports refreshing expired tokens."""
207207
208- _EXECUTOR = ThreadPoolExecutor (max_workers = 10 )
208+ _EXECUTOR = None
209+ _EXECUTOR_LOCK = threading .Lock ()
209210 _DEFAULT_STALE_DURATION = timedelta (minutes = 3 )
210211
212+ @classmethod
213+ def _get_executor (cls ):
214+ """Lazy initialization of the ThreadPoolExecutor."""
215+ if cls ._EXECUTOR is None :
216+ with cls ._EXECUTOR_LOCK :
217+ if cls ._EXECUTOR is None :
218+ # This thread pool has multiple workers because it is shared by all instances of Refreshable.
219+ cls ._EXECUTOR = ThreadPoolExecutor (max_workers = 10 )
220+ return cls ._EXECUTOR
221+
211222 def __init__ (self ,
212223 token : Token = None ,
213224 disable_async : bool = True ,
214225 stale_duration : timedelta = _DEFAULT_STALE_DURATION ):
215- self ._lock = threading .Lock ()
216- self ._token = token
226+ # Config properties
217227 self ._stale_duration = stale_duration
218228 self ._disable_async = disable_async
229+ # Lock
230+ self ._lock = threading .Lock ()
231+ # Non Thread safe properties. Protected by the lock above.
232+ self ._token = token
219233 self ._is_refreshing = False
220234 self ._refresh_err = False
221235
@@ -232,33 +246,37 @@ def _async_token(self) -> Token:
232246 If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233247 """
234248 with self ._lock :
235- token_state = self ._token_state ()
249+ state = Refreshable ._token_state (self . _token , self . _stale_duration )
236250 token = self ._token
237251
238- if token_state == _TokenState .FRESH :
252+ if state == _TokenState .FRESH :
239253 return token
240- if token_state == _TokenState .STALE :
254+ if state == _TokenState .STALE :
241255 self ._trigger_async_refresh ()
242256 return token
243257 return self ._blocking_token ()
244258
245- def _token_state (self ) -> _TokenState :
259+ # This is a class method and we pass the token to avoid
260+ # concurrency issues and deadlocks.
261+ @classmethod
262+ def _token_state (cls , token : Token , stale_duration : timedelta ) -> _TokenState :
246263 """Returns the current state of the token."""
247- if not self . _token or not self . _token .valid :
264+ if not token or not token .valid :
248265 return _TokenState .EXPIRED
249- if not self . _token .expiry :
266+ if not token .expiry :
250267 return _TokenState .FRESH
251268
252- lifespan = self . _token .expiry - datetime .now ()
269+ lifespan = token .expiry - datetime .now ()
253270 if lifespan < timedelta (seconds = 0 ):
254271 return _TokenState .EXPIRED
255- if lifespan < self . _stale_duration :
272+ if lifespan < stale_duration :
256273 return _TokenState .STALE
257274 return _TokenState .FRESH
258275
259276 def _blocking_token (self ) -> Token :
260277 """Returns a token, blocking if necessary to refresh it."""
261278 with self ._lock :
279+ state = Refreshable ._token_state (self ._token , self ._stale_duration )
262280 # This is important to recover from potential previous failed attempts
263281 # to refresh the token asynchronously.
264282 self ._refresh_err = False
@@ -267,7 +285,7 @@ def _blocking_token(self) -> Token:
267285 # It's possible that the token got refreshed (either by a _blocking_refresh or
268286 # an _async_refresh call) while this particular call was waiting to acquire
269287 # the lock. This check avoids refreshing the token again in such cases.
270- if self . _token_state () != _TokenState .EXPIRED :
288+ if state != _TokenState .EXPIRED :
271289 return self ._token
272290
273291 self ._token = self .refresh ()
@@ -276,20 +294,31 @@ def _blocking_token(self) -> Token:
276294 def _trigger_async_refresh (self ):
277295 """Starts an asynchronous refresh if none is in progress."""
278296
279- # Note: _refresh_internal function is not thread safe.
280- # Only call it inside the lock.
281297 def _refresh_internal ():
298+ new_token : Token = None
282299 try :
283- self ._token = self .refresh ()
284- except Exception :
285- self ._refresh_err = True
286- finally :
300+ new_token = self .refresh ()
301+ except Exception as e :
302+ # This happens on a thread, so we don't want to propagate the error.
303+ # Instead, if there is no new_token for any reason, we will disable async refresh below
304+ # But we will do it inside the lock.
305+ logger .warning (f'Tried to refresh token asynchronously, but failed: { e } ' )
306+
307+ with self ._lock :
308+ if new_token is not None :
309+ self ._token = new_token
310+ else :
311+ self ._refresh_err = True
287312 self ._is_refreshing = False
288313
289314 with self ._lock :
315+ state = Refreshable ._token_state (self ._token , self ._stale_duration )
316+ # The token may have been refreshed by another thread.
317+ if state == _TokenState .FRESH :
318+ return
290319 if not self ._is_refreshing and not self ._refresh_err :
291320 self ._is_refreshing = True
292- self . _EXECUTOR .submit (_refresh_internal )
321+ Refreshable . _get_executor () .submit (_refresh_internal )
293322
294323 @abstractmethod
295324 def refresh (self ) -> Token :
0 commit comments