99import urllib .parse
1010import webbrowser
1111from abc import abstractmethod
12+ from concurrent .futures import ThreadPoolExecutor
1213from dataclasses import dataclass
1314from datetime import datetime , timedelta
1415from enum import Enum
1516from http .server import BaseHTTPRequestHandler , HTTPServer
1617from typing import Any , Dict , List , Optional
17- from concurrent .futures import ThreadPoolExecutor
1818
1919import requests
2020import requests .auth
@@ -188,67 +188,67 @@ def retrieve_token(client_id,
188188 except Exception as e :
189189 raise NotImplementedError (f"Not supported yet: { e } " )
190190
191+
191192class _TokenState (Enum ):
192193 """
193- tokenState represents the state of the token. Each token can be in one of
194+ Represents the state of a token. Each token can be in one of
194195 the following three states:
195196 - FRESH: The token is valid.
196197 - STALE: The token is valid but will expire soon.
197198 - EXPIRED: The token has expired and cannot be used.
198-
199- Token state through time:
200- issue time expiry time
201- v v
202- | fresh | stale | expired -> time
203- | valid |
204199 """
205200 FRESH = 1 # The token is valid.
206201 STALE = 2 # The token is valid but will expire soon.
207202 EXPIRED = 3 # The token has expired and cannot be used.
208203
209204
210205class Refreshable (TokenSource ):
211- _executor = ThreadPoolExecutor (max_workers = 10 )
212- _default_stale_duration = 3
206+ """A token source that supports refreshing expired tokens."""
207+
208+ _EXECUTOR = ThreadPoolExecutor (max_workers = 10 )
209+ _DEFAULT_STALE_DURATION = timedelta (minutes = 3 )
213210
214- def __init__ (self , token = None , disable_async = True , stale_duration = timedelta (minutes = _default_stale_duration )):
215- self ._lock = threading .Lock () # to guard _token
211+ def __init__ (self ,
212+ token : Token = None ,
213+ disable_async : bool = True ,
214+ stale_duration : timedelta = _DEFAULT_STALE_DURATION ):
215+ self ._lock = threading .Lock ()
216216 self ._token = token
217217 self ._stale_duration = stale_duration
218218 self ._disable_async = disable_async
219219 self ._is_refreshing = False
220220 self ._refresh_err = False
221221
222- def token (self , blocking = False ) -> Token :
222+ def token (self ) -> Token :
223+ """Returns a valid token, blocking if async refresh is disabled."""
223224 if self ._disable_async :
224225 return self ._blocking_token ()
225226 return self ._async_token ()
226227
227228 def _async_token (self ) -> Token :
228- self ._lock .acquire ()
229- token_state = self ._token_state ()
230- token = self ._token
231- self ._lock .release ()
232- match token_state :
233- case _TokenState .FRESH :
234- return token
235- case _TokenState .STALE :
236- self ._trigger_async_refresh ()
237- return token
238- case _: #Expired
239- return self ._blocking_token ()
229+ """
230+ Returns a token.
231+ If the token is stale, triggers an asynchronous refresh.
232+ If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233+ """
234+ with self ._lock :
235+ token_state = self ._token_state ()
236+ token = self ._token
240237
238+ if token_state == _TokenState .FRESH :
239+ return token
240+ if token_state == _TokenState .STALE :
241+ self ._trigger_async_refresh ()
242+ return token
243+ return self ._blocking_token ()
241244
242245 def _token_state (self ) -> _TokenState :
243- """
244- Returns the state of the token.
245- """
246- # Invalid tokens are considered expired.
246+ """Returns the current state of the token."""
247247 if not self ._token or not self ._token .valid :
248248 return _TokenState .EXPIRED
249- # Tokens without an expiry are considered always.
250249 if not self ._token .expiry :
251250 return _TokenState .FRESH
251+
252252 lifespan = self ._token .expiry - datetime .now ()
253253 if lifespan < timedelta (seconds = 0 ):
254254 return _TokenState .EXPIRED
@@ -257,13 +257,10 @@ def _token_state(self) -> _TokenState:
257257 return _TokenState .FRESH
258258
259259 def _blocking_token (self ) -> Token :
260-
261- # The lock is kept for the entire operation to ensure that only one
262- # refresh operation is running at a time.
260+ """Returns a token, blocking if necessary to refresh it."""
263261 with self ._lock :
264262 # This is important to recover from potential previous failed attempts
265- # to refresh the token asynchronously, see declaration of refresh_err for
266- # more information.
263+ # to refresh the token asynchronously.
267264 self ._refresh_err = False
268265 self ._is_refreshing = False
269266
@@ -273,13 +270,13 @@ def _blocking_token(self) -> Token:
273270 if self ._token_state () != _TokenState .EXPIRED :
274271 return self ._token
275272
276- # Refresh the token
277273 self ._token = self .refresh ()
278274 return self ._token
279275
280-
281276 def _trigger_async_refresh (self ):
282- # Note: this is not thread safe.
277+ """Starts an asynchronous refresh if none is in progress."""
278+
279+ # Note: _refresh_internal function is not thread safe.
283280 # Only call it inside the lock.
284281 def _refresh_internal ():
285282 try :
@@ -288,12 +285,11 @@ def _refresh_internal():
288285 self ._refresh_err = True
289286 finally :
290287 self ._is_refreshing = False
291- # The lock is kept for the entire operation to ensure that only one
292- # refresh operation is running at a time.
288+
293289 with self ._lock :
294290 if not self ._is_refreshing and not self ._refresh_err :
295291 self ._is_refreshing = True
296- self ._executor .submit (_refresh_internal )
292+ self ._EXECUTOR .submit (_refresh_internal )
297293
298294 @abstractmethod
299295 def refresh (self ) -> Token :
0 commit comments