Skip to content

Commit 949fd00

Browse files
committed
Cleanup
1 parent 011890d commit 949fd00

File tree

1 file changed

+38
-44
lines changed

1 file changed

+38
-44
lines changed

databricks/sdk/oauth.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -190,65 +190,64 @@ def retrieve_token(client_id,
190190

191191
class _TokenState(Enum):
192192
"""
193-
tokenState represents the state of the token. Each token can be in one of
193+
Represents the state of a token. Each token can be in one of
194194
the following three states:
195195
- FRESH: The token is valid.
196196
- STALE: The token is valid but will expire soon.
197197
- EXPIRED: The token has expired and cannot be used.
198-
199-
Token state through time:
200-
issue time expiry time
201-
v v
202-
| fresh | stale | expired -> time
203-
| valid |
204198
"""
205-
FRESH = 1 # The token is valid.
206-
STALE = 2 # The token is valid but will expire soon.
207-
EXPIRED = 3 # The token has expired and cannot be used.
208-
199+
FRESH = 1 # The token is valid.
200+
STALE = 2 # The token is valid but will expire soon.
201+
EXPIRED = 3 # The token has expired and cannot be used.
209202

210203
class Refreshable(TokenSource):
211-
_executor = ThreadPoolExecutor(max_workers=10)
212-
_default_stale_duration = 3
204+
"""A token source that supports refreshing expired tokens."""
213205

214-
def __init__(self, token=None, disable_async = True, stale_duration=timedelta(minutes=_default_stale_duration)):
215-
self._lock = threading.Lock() # to guard _token
206+
_EXECUTOR = ThreadPoolExecutor(max_workers=10)
207+
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
208+
209+
def __init__(
210+
self,
211+
token: Token = None,
212+
disable_async: bool = True,
213+
stale_duration: timedelta = _DEFAULT_STALE_DURATION):
214+
self._lock = threading.Lock()
216215
self._token = token
217216
self._stale_duration = stale_duration
218217
self._disable_async = disable_async
219218
self._is_refreshing = False
220219
self._refresh_err = False
221220

222-
def token(self, blocking=False) -> Token:
221+
def token(self) -> Token:
222+
"""Returns a valid token, blocking if async refresh is disabled."""
223223
if self._disable_async:
224224
return self._blocking_token()
225225
return self._async_token()
226226

227227
def _async_token(self) -> Token:
228-
self._lock.acquire()
229-
token_state = self._token_state()
230-
token = self._token
231-
self._lock.release()
232-
match token_state:
233-
case _TokenState.FRESH:
234-
return token
235-
case _TokenState.STALE:
236-
self._trigger_async_refresh()
237-
return token
238-
case _: #Expired
239-
return self._blocking_token()
228+
"""
229+
Returns a token.
230+
If the token is stale, triggers an asynchronous refresh.
231+
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
232+
"""
233+
with self._lock:
234+
token_state = self._token_state()
235+
token = self._token
240236

237+
if token_state == _TokenState.FRESH:
238+
return token
239+
if token_state == _TokenState.STALE:
240+
self._trigger_async_refresh()
241+
return token
242+
return self._blocking_token()
241243

242244
def _token_state(self) -> _TokenState:
243-
"""
244-
Returns the state of the token.
245-
"""
246-
# Invalid tokens are considered expired.
245+
"""Returns the current state of the token."""
247246
if not self._token or not self._token.valid:
248247
return _TokenState.EXPIRED
249-
# Tokens without an expiry are considered always.
250248
if not self._token.expiry:
251249
return _TokenState.FRESH
250+
252251
lifespan = self._token.expiry - datetime.now()
253252
if lifespan < timedelta(seconds=0):
254253
return _TokenState.EXPIRED
@@ -257,13 +256,10 @@ def _token_state(self) -> _TokenState:
257256
return _TokenState.FRESH
258257

259258
def _blocking_token(self) -> Token:
260-
261-
# The lock is kept for the entire operation to ensure that only one
262-
# refresh operation is running at a time.
259+
"""Returns a token, blocking if necessary to refresh it."""
263260
with self._lock:
264261
# This is important to recover from potential previous failed attempts
265-
# to refresh the token asynchronously, see declaration of refresh_err for
266-
# more information.
262+
# to refresh the token asynchronously.
267263
self._refresh_err = False
268264
self._is_refreshing = False
269265

@@ -273,13 +269,12 @@ def _blocking_token(self) -> Token:
273269
if self._token_state() != _TokenState.EXPIRED:
274270
return self._token
275271

276-
# Refresh the token
277272
self._token = self.refresh()
278273
return self._token
279274

280-
281275
def _trigger_async_refresh(self):
282-
# Note: this is not thread safe.
276+
"""Starts an asynchronous refresh if none is in progress."""
277+
# Note: _refresh_internal function is not thread safe.
283278
# Only call it inside the lock.
284279
def _refresh_internal():
285280
try:
@@ -288,12 +283,11 @@ def _refresh_internal():
288283
self._refresh_err = True
289284
finally:
290285
self._is_refreshing = False
291-
# The lock is kept for the entire operation to ensure that only one
292-
# refresh operation is running at a time.
286+
293287
with self._lock:
294288
if not self._is_refreshing and not self._refresh_err:
295289
self._is_refreshing = True
296-
self._executor.submit(_refresh_internal)
290+
self._EXECUTOR.submit(_refresh_internal)
297291

298292
@abstractmethod
299293
def refresh(self) -> Token:

0 commit comments

Comments
 (0)