Skip to content

Commit 51b9228

Browse files
committed
Cleanup
1 parent 011890d commit 51b9228

File tree

2 files changed

+88
-120
lines changed

2 files changed

+88
-120
lines changed

databricks/sdk/oauth.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import urllib.parse
1010
import webbrowser
1111
from abc import abstractmethod
12+
from concurrent.futures import ThreadPoolExecutor
1213
from dataclasses import dataclass
1314
from datetime import datetime, timedelta
1415
from enum import Enum
1516
from http.server import BaseHTTPRequestHandler, HTTPServer
1617
from typing import Any, Dict, List, Optional
17-
from concurrent.futures import ThreadPoolExecutor
1818

1919
import requests
2020
import requests.auth
@@ -188,67 +188,67 @@ def retrieve_token(client_id,
188188
except Exception as e:
189189
raise NotImplementedError(f"Not supported yet: {e}")
190190

191+
191192
class _TokenState(Enum):
192193
"""
193-
tokenState represents the state of the token. Each token can be in one of
194+
Represents the state of a token. Each token can be in one of
194195
the following three states:
195196
- FRESH: The token is valid.
196197
- STALE: The token is valid but will expire soon.
197198
- EXPIRED: The token has expired and cannot be used.
198-
199-
Token state through time:
200-
issue time expiry time
201-
v v
202-
| fresh | stale | expired -> time
203-
| valid |
204199
"""
205200
FRESH = 1 # The token is valid.
206201
STALE = 2 # The token is valid but will expire soon.
207202
EXPIRED = 3 # The token has expired and cannot be used.
208203

209204

210205
class Refreshable(TokenSource):
211-
_executor = ThreadPoolExecutor(max_workers=10)
212-
_default_stale_duration = 3
206+
"""A token source that supports refreshing expired tokens."""
207+
208+
_EXECUTOR = ThreadPoolExecutor(max_workers=10)
209+
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
213210

214-
def __init__(self, token=None, disable_async = True, stale_duration=timedelta(minutes=_default_stale_duration)):
215-
self._lock = threading.Lock() # to guard _token
211+
def __init__(self,
212+
token: Token = None,
213+
disable_async: bool = True,
214+
stale_duration: timedelta = _DEFAULT_STALE_DURATION):
215+
self._lock = threading.Lock()
216216
self._token = token
217217
self._stale_duration = stale_duration
218218
self._disable_async = disable_async
219219
self._is_refreshing = False
220220
self._refresh_err = False
221221

222-
def token(self, blocking=False) -> Token:
222+
def token(self) -> Token:
223+
"""Returns a valid token, blocking if async refresh is disabled."""
223224
if self._disable_async:
224225
return self._blocking_token()
225226
return self._async_token()
226227

227228
def _async_token(self) -> Token:
228-
self._lock.acquire()
229-
token_state = self._token_state()
230-
token = self._token
231-
self._lock.release()
232-
match token_state:
233-
case _TokenState.FRESH:
234-
return token
235-
case _TokenState.STALE:
236-
self._trigger_async_refresh()
237-
return token
238-
case _: #Expired
239-
return self._blocking_token()
229+
"""
230+
Returns a token.
231+
If the token is stale, triggers an asynchronous refresh.
232+
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233+
"""
234+
with self._lock:
235+
token_state = self._token_state()
236+
token = self._token
240237

238+
if token_state == _TokenState.FRESH:
239+
return token
240+
if token_state == _TokenState.STALE:
241+
self._trigger_async_refresh()
242+
return token
243+
return self._blocking_token()
241244

242245
def _token_state(self) -> _TokenState:
243-
"""
244-
Returns the state of the token.
245-
"""
246-
# Invalid tokens are considered expired.
246+
"""Returns the current state of the token."""
247247
if not self._token or not self._token.valid:
248248
return _TokenState.EXPIRED
249-
# Tokens without an expiry are considered always.
250249
if not self._token.expiry:
251250
return _TokenState.FRESH
251+
252252
lifespan = self._token.expiry - datetime.now()
253253
if lifespan < timedelta(seconds=0):
254254
return _TokenState.EXPIRED
@@ -257,13 +257,10 @@ def _token_state(self) -> _TokenState:
257257
return _TokenState.FRESH
258258

259259
def _blocking_token(self) -> Token:
260-
261-
# The lock is kept for the entire operation to ensure that only one
262-
# refresh operation is running at a time.
260+
"""Returns a token, blocking if necessary to refresh it."""
263261
with self._lock:
264262
# This is important to recover from potential previous failed attempts
265-
# to refresh the token asynchronously, see declaration of refresh_err for
266-
# more information.
263+
# to refresh the token asynchronously.
267264
self._refresh_err = False
268265
self._is_refreshing = False
269266

@@ -273,13 +270,13 @@ def _blocking_token(self) -> Token:
273270
if self._token_state() != _TokenState.EXPIRED:
274271
return self._token
275272

276-
# Refresh the token
277273
self._token = self.refresh()
278274
return self._token
279275

280-
281276
def _trigger_async_refresh(self):
282-
# Note: this is not thread safe.
277+
"""Starts an asynchronous refresh if none is in progress."""
278+
279+
# Note: _refresh_internal function is not thread safe.
283280
# Only call it inside the lock.
284281
def _refresh_internal():
285282
try:
@@ -288,12 +285,11 @@ def _refresh_internal():
288285
self._refresh_err = True
289286
finally:
290287
self._is_refreshing = False
291-
# The lock is kept for the entire operation to ensure that only one
292-
# refresh operation is running at a time.
288+
293289
with self._lock:
294290
if not self._is_refreshing and not self._refresh_err:
295291
self._is_refreshing = True
296-
self._executor.submit(_refresh_internal)
292+
self._EXECUTOR.submit(_refresh_internal)
297293

298294
@abstractmethod
299295
def refresh(self) -> Token:

0 commit comments

Comments
 (0)