Skip to content

Commit 21c999f

Browse files
committed
Lazy load the executor
1 parent 51b9228 commit 21c999f

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

databricks/sdk/oauth.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,31 @@ class _TokenState(Enum):
205205
class Refreshable(TokenSource):
206206
"""A token source that supports refreshing expired tokens."""
207207

208-
_EXECUTOR = ThreadPoolExecutor(max_workers=10)
208+
_EXECUTOR = None
209+
_EXECUTOR_LOCK = threading.Lock()
209210
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
210211

212+
@classmethod
213+
def _get_executor(cls):
214+
"""Lazy initialization of the ThreadPoolExecutor."""
215+
if cls._EXECUTOR is None:
216+
with cls._EXECUTOR_LOCK:
217+
if cls._EXECUTOR is None:
218+
# This thread pool has multiple workers because it is shared by all instances of Refreshable.
219+
cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
220+
return cls._EXECUTOR
221+
211222
def __init__(self,
212223
token: Token = None,
213224
disable_async: bool = True,
214225
stale_duration: timedelta = _DEFAULT_STALE_DURATION):
215-
self._lock = threading.Lock()
216-
self._token = token
226+
# Config properties
217227
self._stale_duration = stale_duration
218228
self._disable_async = disable_async
229+
# Lock
230+
self._lock = threading.Lock()
231+
# Non Thread safe properties. Protected by the lock above.
232+
self._token = token
219233
self._is_refreshing = False
220234
self._refresh_err = False
221235

@@ -232,33 +246,37 @@ def _async_token(self) -> Token:
232246
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233247
"""
234248
with self._lock:
235-
token_state = self._token_state()
249+
state = Refreshable._token_state(self._token, self._stale_duration)
236250
token = self._token
237251

238-
if token_state == _TokenState.FRESH:
252+
if state == _TokenState.FRESH:
239253
return token
240-
if token_state == _TokenState.STALE:
254+
if state == _TokenState.STALE:
241255
self._trigger_async_refresh()
242256
return token
243257
return self._blocking_token()
244258

245-
def _token_state(self) -> _TokenState:
259+
# This is a class method and we pass the token to avoid
260+
# concurrency issues and deadlocks.
261+
@classmethod
262+
def _token_state(cls, token: Token, stale_duration: timedelta) -> _TokenState:
246263
"""Returns the current state of the token."""
247-
if not self._token or not self._token.valid:
264+
if not token or not token.valid:
248265
return _TokenState.EXPIRED
249-
if not self._token.expiry:
266+
if not token.expiry:
250267
return _TokenState.FRESH
251268

252-
lifespan = self._token.expiry - datetime.now()
269+
lifespan = token.expiry - datetime.now()
253270
if lifespan < timedelta(seconds=0):
254271
return _TokenState.EXPIRED
255-
if lifespan < self._stale_duration:
272+
if lifespan < stale_duration:
256273
return _TokenState.STALE
257274
return _TokenState.FRESH
258275

259276
def _blocking_token(self) -> Token:
260277
"""Returns a token, blocking if necessary to refresh it."""
261278
with self._lock:
279+
state = Refreshable._token_state(self._token, self._stale_duration)
262280
# This is important to recover from potential previous failed attempts
263281
# to refresh the token asynchronously.
264282
self._refresh_err = False
@@ -267,7 +285,7 @@ def _blocking_token(self) -> Token:
267285
# It's possible that the token got refreshed (either by a _blocking_refresh or
268286
# an _async_refresh call) while this particular call was waiting to acquire
269287
# the lock. This check avoids refreshing the token again in such cases.
270-
if self._token_state() != _TokenState.EXPIRED:
288+
if state != _TokenState.EXPIRED:
271289
return self._token
272290

273291
self._token = self.refresh()
@@ -276,20 +294,31 @@ def _blocking_token(self) -> Token:
276294
def _trigger_async_refresh(self):
277295
"""Starts an asynchronous refresh if none is in progress."""
278296

279-
# Note: _refresh_internal function is not thread safe.
280-
# Only call it inside the lock.
281297
def _refresh_internal():
298+
new_token: Token = None
282299
try:
283-
self._token = self.refresh()
284-
except Exception:
285-
self._refresh_err = True
286-
finally:
300+
new_token = self.refresh()
301+
except Exception as e:
302+
# This happens on a thread, so we don't want to propagate the error.
303+
# Instead, if there is no new_token for any reason, we will disable async refresh below
304+
# But we will do it inside the lock.
305+
logger.warning(f'Tried to refresh token asynchronously, but failed: {e}')
306+
307+
with self._lock:
308+
if new_token is not None:
309+
self._token = new_token
310+
else:
311+
self._refresh_err = True
287312
self._is_refreshing = False
288313

289314
with self._lock:
315+
state = Refreshable._token_state(self._token, self._stale_duration)
316+
# The token may have been refreshed by another thread.
317+
if state == _TokenState.FRESH:
318+
return
290319
if not self._is_refreshing and not self._refresh_err:
291320
self._is_refreshing = True
292-
self._EXECUTOR.submit(_refresh_internal)
321+
Refreshable._get_executor().submit(_refresh_internal)
293322

294323
@abstractmethod
295324
def refresh(self) -> Token:

tests/test_refreshable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def refresh(self) -> Token:
2525

2626

2727
def fail() -> Token:
28-
raise Exception("Failed to refresh token")
28+
raise Exception("Simulated token refresh failure")
2929

3030

3131
def static_token(token: Token, wait: int = 0) -> Callable[[], Token]:

0 commit comments

Comments
 (0)