@@ -190,65 +190,64 @@ def retrieve_token(client_id,
190190
191191class _TokenState (Enum ):
192192 """
193- tokenState represents the state of the token. Each token can be in one of
193+ Represents the state of a token. Each token can be in one of
194194 the following three states:
195195 - FRESH: The token is valid.
196196 - STALE: The token is valid but will expire soon.
197197 - EXPIRED: The token has expired and cannot be used.
198-
199- Token state through time:
200- issue time expiry time
201- v v
202- | fresh | stale | expired -> time
203- | valid |
204198 """
205- FRESH = 1 # The token is valid.
206- STALE = 2 # The token is valid but will expire soon.
207- EXPIRED = 3 # The token has expired and cannot be used.
208-
199+ FRESH = 1 # The token is valid.
200+ STALE = 2 # The token is valid but will expire soon.
201+ EXPIRED = 3 # The token has expired and cannot be used.
209202
210203class Refreshable (TokenSource ):
211- _executor = ThreadPoolExecutor (max_workers = 10 )
212- _default_stale_duration = 3
204+ """A token source that supports refreshing expired tokens."""
213205
214- def __init__ (self , token = None , disable_async = True , stale_duration = timedelta (minutes = _default_stale_duration )):
215- self ._lock = threading .Lock () # to guard _token
206+ _EXECUTOR = ThreadPoolExecutor (max_workers = 10 )
207+ _DEFAULT_STALE_DURATION = timedelta (minutes = 3 )
208+
209+ def __init__ (
210+ self ,
211+ token : Token = None ,
212+ disable_async : bool = True ,
213+ stale_duration : timedelta = _DEFAULT_STALE_DURATION ):
214+ self ._lock = threading .Lock ()
216215 self ._token = token
217216 self ._stale_duration = stale_duration
218217 self ._disable_async = disable_async
219218 self ._is_refreshing = False
220219 self ._refresh_err = False
221220
222- def token (self , blocking = False ) -> Token :
221+ def token (self ) -> Token :
222+ """Returns a valid token, blocking if async refresh is disabled."""
223223 if self ._disable_async :
224224 return self ._blocking_token ()
225225 return self ._async_token ()
226226
227227 def _async_token (self ) -> Token :
228- self ._lock .acquire ()
229- token_state = self ._token_state ()
230- token = self ._token
231- self ._lock .release ()
232- match token_state :
233- case _TokenState .FRESH :
234- return token
235- case _TokenState .STALE :
236- self ._trigger_async_refresh ()
237- return token
238- case _: #Expired
239- return self ._blocking_token ()
228+ """
229+ Returns a token.
230+ If the token is stale, triggers an asynchronous refresh.
231+ If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
232+ """
233+ with self ._lock :
234+ token_state = self ._token_state ()
235+ token = self ._token
240236
237+ if token_state == _TokenState .FRESH :
238+ return token
239+ if token_state == _TokenState .STALE :
240+ self ._trigger_async_refresh ()
241+ return token
242+ return self ._blocking_token ()
241243
242244 def _token_state (self ) -> _TokenState :
243- """
244- Returns the state of the token.
245- """
246- # Invalid tokens are considered expired.
245+ """Returns the current state of the token."""
247246 if not self ._token or not self ._token .valid :
248247 return _TokenState .EXPIRED
249- # Tokens without an expiry are considered always.
250248 if not self ._token .expiry :
251249 return _TokenState .FRESH
250+
252251 lifespan = self ._token .expiry - datetime .now ()
253252 if lifespan < timedelta (seconds = 0 ):
254253 return _TokenState .EXPIRED
@@ -257,13 +256,10 @@ def _token_state(self) -> _TokenState:
257256 return _TokenState .FRESH
258257
259258 def _blocking_token (self ) -> Token :
260-
261- # The lock is kept for the entire operation to ensure that only one
262- # refresh operation is running at a time.
259+ """Returns a token, blocking if necessary to refresh it."""
263260 with self ._lock :
264261 # This is important to recover from potential previous failed attempts
265- # to refresh the token asynchronously, see declaration of refresh_err for
266- # more information.
262+ # to refresh the token asynchronously.
267263 self ._refresh_err = False
268264 self ._is_refreshing = False
269265
@@ -273,13 +269,12 @@ def _blocking_token(self) -> Token:
273269 if self ._token_state () != _TokenState .EXPIRED :
274270 return self ._token
275271
276- # Refresh the token
277272 self ._token = self .refresh ()
278273 return self ._token
279274
280-
281275 def _trigger_async_refresh (self ):
282- # Note: this is not thread safe.
276+ """Starts an asynchronous refresh if none is in progress."""
277+ # Note: _refresh_internal function is not thread safe.
283278 # Only call it inside the lock.
284279 def _refresh_internal ():
285280 try :
@@ -288,12 +283,11 @@ def _refresh_internal():
288283 self ._refresh_err = True
289284 finally :
290285 self ._is_refreshing = False
291- # The lock is kept for the entire operation to ensure that only one
292- # refresh operation is running at a time.
286+
293287 with self ._lock :
294288 if not self ._is_refreshing and not self ._refresh_err :
295289 self ._is_refreshing = True
296- self ._executor .submit (_refresh_internal )
290+ self ._EXECUTOR .submit (_refresh_internal )
297291
298292 @abstractmethod
299293 def refresh (self ) -> Token :
0 commit comments