Skip to content

Commit 11c42a2

Browse files
committed
Lazy load the executor
1 parent 51b9228 commit 11c42a2

File tree

2 files changed

+47
-23
lines changed

2 files changed

+47
-23
lines changed

databricks/sdk/oauth.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,31 @@ class _TokenState(Enum):
205205
class Refreshable(TokenSource):
206206
"""A token source that supports refreshing expired tokens."""
207207

208-
_EXECUTOR = ThreadPoolExecutor(max_workers=10)
208+
_EXECUTOR = None
209+
_EXECUTOR_LOCK = threading.Lock()
209210
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
210211

212+
@classmethod
213+
def _get_executor(cls):
214+
"""Lazy initialization of the ThreadPoolExecutor."""
215+
if cls._EXECUTOR is None:
216+
with cls._EXECUTOR_LOCK:
217+
if cls._EXECUTOR is None:
218+
# This thread pool has multiple workers because it is shared by all instances of Refreshable.
219+
cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
220+
return cls._EXECUTOR
221+
211222
def __init__(self,
212223
token: Token = None,
213224
disable_async: bool = True,
214225
stale_duration: timedelta = _DEFAULT_STALE_DURATION):
215-
self._lock = threading.Lock()
216-
self._token = token
226+
# Config properties
217227
self._stale_duration = stale_duration
218228
self._disable_async = disable_async
229+
# Lock
230+
self._lock = threading.Lock()
231+
# Non Thread safe properties. Protected by the lock above.
232+
self._token = token
219233
self._is_refreshing = False
220234
self._refresh_err = False
221235

@@ -231,8 +245,9 @@ def _async_token(self) -> Token:
231245
If the token is stale, triggers an asynchronous refresh.
232246
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233247
"""
248+
token_state = self._token_state()
249+
234250
with self._lock:
235-
token_state = self._token_state()
236251
token = self._token
237252

238253
if token_state == _TokenState.FRESH:
@@ -244,20 +259,22 @@ def _async_token(self) -> Token:
244259

245260
def _token_state(self) -> _TokenState:
246261
"""Returns the current state of the token."""
247-
if not self._token or not self._token.valid:
248-
return _TokenState.EXPIRED
249-
if not self._token.expiry:
262+
with self._lock:
263+
if not self._token or not self._token.valid:
264+
return _TokenState.EXPIRED
265+
if not self._token.expiry:
266+
return _TokenState.FRESH
267+
268+
lifespan = self._token.expiry - datetime.now()
269+
if lifespan < timedelta(seconds=0):
270+
return _TokenState.EXPIRED
271+
if lifespan < self._stale_duration:
272+
return _TokenState.STALE
250273
return _TokenState.FRESH
251274

252-
lifespan = self._token.expiry - datetime.now()
253-
if lifespan < timedelta(seconds=0):
254-
return _TokenState.EXPIRED
255-
if lifespan < self._stale_duration:
256-
return _TokenState.STALE
257-
return _TokenState.FRESH
258-
259275
def _blocking_token(self) -> Token:
260276
"""Returns a token, blocking if necessary to refresh it."""
277+
state = self._token_state()
261278
with self._lock:
262279
# This is important to recover from potential previous failed attempts
263280
# to refresh the token asynchronously.
@@ -267,7 +284,7 @@ def _blocking_token(self) -> Token:
267284
# It's possible that the token got refreshed (either by a _blocking_refresh or
268285
# an _async_refresh call) while this particular call was waiting to acquire
269286
# the lock. This check avoids refreshing the token again in such cases.
270-
if self._token_state() != _TokenState.EXPIRED:
287+
if state != _TokenState.EXPIRED:
271288
return self._token
272289

273290
self._token = self.refresh()
@@ -276,20 +293,27 @@ def _blocking_token(self) -> Token:
276293
def _trigger_async_refresh(self):
277294
"""Starts an asynchronous refresh if none is in progress."""
278295

279-
# Note: _refresh_internal function is not thread safe.
280-
# Only call it inside the lock.
281296
def _refresh_internal():
297+
new_token: Token = None
282298
try:
283-
self._token = self.refresh()
284-
except Exception:
285-
self._refresh_err = True
286-
finally:
299+
new_token = self.refresh()
300+
except Exception as e:
301+
# This happens on a thread, so we don't want to propagate the error.
302+
# Instead, if there is no new_token for any reason, we will disable async refresh below
303+
# But we will do it inside the lock.
304+
logger.warning(f'Tried to refresh token asynchronously, but failed: {e}')
305+
306+
with self._lock:
307+
if new_token is not None:
308+
self._token = new_token
309+
else:
310+
self._refresh_err = True
287311
self._is_refreshing = False
288312

289313
with self._lock:
290314
if not self._is_refreshing and not self._refresh_err:
291315
self._is_refreshing = True
292-
self._EXECUTOR.submit(_refresh_internal)
316+
Refreshable._get_executor().submit(_refresh_internal)
293317

294318
@abstractmethod
295319
def refresh(self) -> Token:

tests/test_refreshable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def refresh(self) -> Token:
2525

2626

2727
def fail() -> Token:
28-
raise Exception("Failed to refresh token")
28+
raise Exception("Simulated token refresh failure")
2929

3030

3131
def static_token(token: Token, wait: int = 0) -> Callable[[], Token]:

0 commit comments

Comments
 (0)