Skip to content

Commit cdcbf27

Browse files
committed
Lazy load the executor
1 parent 51b9228 commit cdcbf27

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

databricks/sdk/oauth.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import base64
23
import functools
34
import hashlib
@@ -205,17 +206,31 @@ class _TokenState(Enum):
205206
class Refreshable(TokenSource):
206207
"""A token source that supports refreshing expired tokens."""
207208

208-
_EXECUTOR = ThreadPoolExecutor(max_workers=10)
209+
_EXECUTOR = None
210+
_EXECUTOR_LOCK = threading.Lock()
209211
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
210212

213+
@classmethod
214+
def _get_executor(cls):
215+
"""Lazy initialization of the ThreadPoolExecutor."""
216+
if cls._EXECUTOR is None:
217+
with cls._EXECUTOR_LOCK:
218+
if cls._EXECUTOR is None:
219+
# This thread pool has multiple workers because it is shared by all instances of Refreshable.
220+
cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
221+
return cls._EXECUTOR
222+
211223
def __init__(self,
212224
token: Token = None,
213225
disable_async: bool = True,
214226
stale_duration: timedelta = _DEFAULT_STALE_DURATION):
215-
self._lock = threading.Lock()
216-
self._token = token
227+
# Config properties
217228
self._stale_duration = stale_duration
218229
self._disable_async = disable_async
230+
# Lock
231+
self._lock = threading.Lock()
232+
# Non Thread safe properties. Protected by the lock above.
233+
self._token = token
219234
self._is_refreshing = False
220235
self._refresh_err = False
221236

@@ -231,8 +246,9 @@ def _async_token(self) -> Token:
231246
If the token is stale, triggers an asynchronous refresh.
232247
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
233248
"""
234-
with self._lock:
235-
token_state = self._token_state()
249+
token_state = self._token_state()
250+
251+
with self._lock:
236252
token = self._token
237253

238254
if token_state == _TokenState.FRESH:
@@ -244,20 +260,22 @@ def _async_token(self) -> Token:
244260

245261
def _token_state(self) -> _TokenState:
246262
"""Returns the current state of the token."""
247-
if not self._token or not self._token.valid:
248-
return _TokenState.EXPIRED
249-
if not self._token.expiry:
263+
with self._lock:
264+
if not self._token or not self._token.valid:
265+
return _TokenState.EXPIRED
266+
if not self._token.expiry:
267+
return _TokenState.FRESH
268+
269+
lifespan = self._token.expiry - datetime.now()
270+
if lifespan < timedelta(seconds=0):
271+
return _TokenState.EXPIRED
272+
if lifespan < self._stale_duration:
273+
return _TokenState.STALE
250274
return _TokenState.FRESH
251275

252-
lifespan = self._token.expiry - datetime.now()
253-
if lifespan < timedelta(seconds=0):
254-
return _TokenState.EXPIRED
255-
if lifespan < self._stale_duration:
256-
return _TokenState.STALE
257-
return _TokenState.FRESH
258-
259276
def _blocking_token(self) -> Token:
260277
"""Returns a token, blocking if necessary to refresh it."""
278+
state = self._token_state()
261279
with self._lock:
262280
# This is important to recover from potential previous failed attempts
263281
# to refresh the token asynchronously.
@@ -267,7 +285,7 @@ def _blocking_token(self) -> Token:
267285
# It's possible that the token got refreshed (either by a _blocking_refresh or
268286
# an _async_refresh call) while this particular call was waiting to acquire
269287
# the lock. This check avoids refreshing the token again in such cases.
270-
if self._token_state() != _TokenState.EXPIRED:
288+
if state != _TokenState.EXPIRED:
271289
return self._token
272290

273291
self._token = self.refresh()
@@ -276,20 +294,29 @@ def _blocking_token(self) -> Token:
276294
def _trigger_async_refresh(self):
277295
"""Starts an asynchronous refresh if none is in progress."""
278296

279-
# Note: _refresh_internal function is not thread safe.
280-
# Only call it inside the lock.
281297
def _refresh_internal():
298+
new_token: Token = None
282299
try:
283-
self._token = self.refresh()
284-
except Exception:
285-
self._refresh_err = True
286-
finally:
300+
new_token = self.refresh()
301+
except Exception as e:
302+
# This happens on a thread, so we don't want to propagate the error.
303+
# Instead, if there is no new_token for any reason, we will disable async refresh below
304+
# But we will do it inside the lock.
305+
logger.warning(f'Tried to refresh token asynchronously, but failed: {e}')
306+
pass
307+
308+
with self._lock:
309+
if new_token is not None:
310+
self._token = new_token
311+
else:
312+
self._refresh_err = True
287313
self._is_refreshing = False
314+
288315

289316
with self._lock:
290317
if not self._is_refreshing and not self._refresh_err:
291318
self._is_refreshing = True
292-
self._EXECUTOR.submit(_refresh_internal)
319+
Refreshable._get_executor().submit(_refresh_internal)
293320

294321
@abstractmethod
295322
def refresh(self) -> Token:

tests/test_refreshable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def refresh(self) -> Token:
2525

2626

2727
def fail() -> Token:
28-
raise Exception("Failed to refresh token")
28+
raise Exception("Simulated token refresh failure")
2929

3030

3131
def static_token(token: Token, wait: int = 0) -> Callable[[], Token]:

0 commit comments

Comments
 (0)