From 011890d33100a44a738fd0329e2e6e1cf0f930da Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Tue, 18 Feb 2025 16:05:06 +0100 Subject: [PATCH 1/4] WIP --- databricks/sdk/oauth.py | 104 +++++++++++++++- tests/test_refreshable.py | 244 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 342 insertions(+), 6 deletions(-) create mode 100644 tests/test_refreshable.py diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 6cac45afc..64825a474 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -11,8 +11,10 @@ from abc import abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta +from enum import Enum from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any, Dict, List, Optional +from concurrent.futures import ThreadPoolExecutor import requests import requests.auth @@ -186,22 +188,112 @@ def retrieve_token(client_id, except Exception as e: raise NotImplementedError(f"Not supported yet: {e}") +class _TokenState(Enum): + """ + tokenState represents the state of the token. Each token can be in one of + the following three states: + - FRESH: The token is valid. + - STALE: The token is valid but will expire soon. + - EXPIRED: The token has expired and cannot be used. + + Token state through time: + issue time expiry time + v v + | fresh | stale | expired -> time + | valid | + """ + FRESH = 1 # The token is valid. + STALE = 2 # The token is valid but will expire soon. + EXPIRED = 3 # The token has expired and cannot be used. + class Refreshable(TokenSource): + _executor = ThreadPoolExecutor(max_workers=10) + _default_stale_duration = 3 - def __init__(self, token=None): + def __init__(self, token=None, disable_async = True, stale_duration=timedelta(minutes=_default_stale_duration)): self._lock = threading.Lock() # to guard _token self._token = token + self._stale_duration = stale_duration + self._disable_async = disable_async + self._is_refreshing = False + self._refresh_err = False - def token(self) -> Token: + def token(self, blocking=False) -> Token: + if self._disable_async: + return self._blocking_token() + return self._async_token() + + def _async_token(self) -> Token: self._lock.acquire() - try: - if self._token and self._token.valid: + token_state = self._token_state() + token = self._token + self._lock.release() + match token_state: + case _TokenState.FRESH: + return token + case _TokenState.STALE: + self._trigger_async_refresh() + return token + case _: #Expired + return self._blocking_token() + + + def _token_state(self) -> _TokenState: + """ + Returns the state of the token. + """ + # Invalid tokens are considered expired. + if not self._token or not self._token.valid: + return _TokenState.EXPIRED + # Tokens without an expiry are considered always. + if not self._token.expiry: + return _TokenState.FRESH + lifespan = self._token.expiry - datetime.now() + if lifespan < timedelta(seconds=0): + return _TokenState.EXPIRED + if lifespan < self._stale_duration: + return _TokenState.STALE + return _TokenState.FRESH + + def _blocking_token(self) -> Token: + + # The lock is kept for the entire operation to ensure that only one + # refresh operation is running at a time. + with self._lock: + # This is important to recover from potential previous failed attempts + # to refresh the token asynchronously, see declaration of refresh_err for + # more information. + self._refresh_err = False + self._is_refreshing = False + + # It's possible that the token got refreshed (either by a _blocking_refresh or + # an _async_refresh call) while this particular call was waiting to acquire + # the lock. This check avoids refreshing the token again in such cases. + if self._token_state() != _TokenState.EXPIRED: return self._token + + # Refresh the token self._token = self.refresh() return self._token - finally: - self._lock.release() + + + def _trigger_async_refresh(self): + # Note: this is not thread safe. + # Only call it inside the lock. + def _refresh_internal(): + try: + self._token = self.refresh() + except Exception: + self._refresh_err = True + finally: + self._is_refreshing = False + # The lock is kept for the entire operation to ensure that only one + # refresh operation is running at a time. + with self._lock: + if not self._is_refreshing and not self._refresh_err: + self._is_refreshing = True + self._executor.submit(_refresh_internal) @abstractmethod def refresh(self) -> Token: diff --git a/tests/test_refreshable.py b/tests/test_refreshable.py new file mode 100644 index 000000000..be28d9531 --- /dev/null +++ b/tests/test_refreshable.py @@ -0,0 +1,244 @@ +import time +from time import sleep +from typing import Callable + +import pytest + +from datetime import datetime, timedelta + +from databricks.sdk.oauth import Refreshable, Token + + +class _MockRefreshable(Refreshable): + + def __init__(self, disable_async, token=None, stale_duration=timedelta(seconds=60), refresh_effect: Callable[[], Token]=None): + super().__init__(token, disable_async, stale_duration) + self._refresh_effect = refresh_effect + self._refresh_count = 0 + + def refresh(self) -> Token: + if self._refresh_effect: + self._token = self._refresh_effect() + self._refresh_count += 1 + return self._token + +def fail() -> Token: + raise Exception("Failed to refresh token") + +def static_token(token: Token, wait: int=0) -> Callable[[], Token]: + def f() -> Token: + time.sleep(wait) + return token + return f + +def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[],None]): + """ + Create a refresh function that blocks until unblock is called. + + Param: + token: the token that will be returned + + Returns: + A tuple containing the refresh function and the unblock function. + + """ + blocking = True + def refresh(): + while blocking: + sleep(0.1) + return token + def unblock(): + nonlocal blocking + blocking = False + return refresh, unblock + + +def test_disable_async_stale_does_not_refresh(): + stale_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=50), + ) + r = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == stale_token + +def test_disable_async_no_token_does_refresh(): + token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=50), + ) + r = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token)) + result = r.token() + assert r._refresh_count == 1 + assert result == token + +def test_disable_async_no_expiration_does_not_refresh(): + non_expiring_token = Token( + access_token="access_token", + ) + r = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == non_expiring_token + +def test_disable_async_fresh_does_not_refresh(): + # Create a token that is already stale. If async is disabled, the token should not be refreshed. + token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + r = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == token + +def test_disable_async_expired_does_refresh(): + expired_token = Token( + access_token="access_token", + expiry=datetime.now() - timedelta(seconds=300), + ) + new_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + # Add one second to the refresh time to ensure that the call is blocking. + # If the call is not blocking, the wait time will ensure that the + # old token is returned. + r = _MockRefreshable(token=expired_token, disable_async=True, refresh_effect=static_token(new_token, wait=1)) + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + +def test_expired_does_refresh(): + expired_token = Token( + access_token="access_token", + expiry=datetime.now() - timedelta(seconds=300), + ) + new_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + # Add one second to the refresh time to ensure that the call is blocking. + # If the call is not blocking, the wait time will ensure that the + # old token is returned. + r = _MockRefreshable(token=expired_token, disable_async=False, refresh_effect=static_token(new_token, wait=1)) + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + +def test_stale_does_refresh_async(): + stale_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=50), + ) + new_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + # Add one second to the refresh to avoid race conditions. + # Without it, the new token may be returned in some cases. + refresh, unblock = blocking_refresh(new_token) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) + result = r.token() + # NOTE: Do not check for refresh count here, since the + assert result == stale_token + assert r._refresh_count == 0 + # Unblock the refresh and wait + unblock() + time.sleep(2) + # Call again and check that you get the new token + result = r.token() + assert result == new_token + # Ensure that all calls have completed + time.sleep(0.1) + assert r._refresh_count == 1 + + +def test_no_token_does_refresh(): + new_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + # Add one second to the refresh time to ensure that the call is blocking. + # If the call is not blocking, the wait time will ensure that the + # token is not returned. + r = _MockRefreshable(token=None, disable_async=False, refresh_effect=static_token(new_token, wait=1)) + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + +def test_fresh_does_not_refresh(): + fresh_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + r = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == fresh_token + +def test_multiple_calls_dont_start_many_threads(): + stale_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=59), + ) + new_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + refresh, unblock = blocking_refresh(new_token) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) + # Call twice. The second call should not start a new thread. + result = r.token() + assert result == stale_token + result = r.token() + assert result == stale_token + unblock() + # Wait for the refresh to complete + time.sleep(1) + result = r.token() + # Check that only one refresh was called + assert r._refresh_count == 1 + assert result == new_token + +def test_async_failure_disables_async(): + stale_token = Token( + access_token="access_token", + expiry=datetime.now() + timedelta(seconds=59), + ) + new_token = Token( + access_token="new_token", + expiry=datetime.now() + timedelta(seconds=300), + ) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail) + # The call should fail and disable async refresh, + # but the exception will be catch inside the tread. + result = r.token() + assert result == stale_token + # Give time to the async refresh to fail + time.sleep(1) + assert r._refresh_err + # Now, the refresh should be blocking. + # Blocking refresh only happens for expired, not stale. + # Therefore, the next call should return the stale token. + r._refresh_effect = static_token(new_token, wait=1) + result = r.token() + assert result == stale_token + # Wait to be sure no async thread was started + time.sleep(1) + assert r._refresh_count == 0 + + # Inject an expired token. + expired_token = Token( + access_token="access_token", + expiry=datetime.now() - timedelta(seconds=300), + ) + r._token = expired_token + + # This should be blocking and return the new token. + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + # The refresh error should be cleared. + assert not r._refresh_err From 51b9228e9534ce7757b4be0bc27546fd8fa771ae Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Wed, 19 Feb 2025 08:50:52 +0100 Subject: [PATCH 2/4] Cleanup --- databricks/sdk/oauth.py | 78 +++++++++++------------ tests/test_refreshable.py | 130 +++++++++++++++----------------------- 2 files changed, 88 insertions(+), 120 deletions(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 64825a474..e16e006af 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -9,12 +9,12 @@ import urllib.parse import webbrowser from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any, Dict, List, Optional -from concurrent.futures import ThreadPoolExecutor import requests import requests.auth @@ -188,19 +188,14 @@ def retrieve_token(client_id, except Exception as e: raise NotImplementedError(f"Not supported yet: {e}") + class _TokenState(Enum): """ - tokenState represents the state of the token. Each token can be in one of + Represents the state of a token. Each token can be in one of the following three states: - FRESH: The token is valid. - STALE: The token is valid but will expire soon. - EXPIRED: The token has expired and cannot be used. - - Token state through time: - issue time expiry time - v v - | fresh | stale | expired -> time - | valid | """ FRESH = 1 # The token is valid. STALE = 2 # The token is valid but will expire soon. @@ -208,47 +203,52 @@ class _TokenState(Enum): class Refreshable(TokenSource): - _executor = ThreadPoolExecutor(max_workers=10) - _default_stale_duration = 3 + """A token source that supports refreshing expired tokens.""" + + _EXECUTOR = ThreadPoolExecutor(max_workers=10) + _DEFAULT_STALE_DURATION = timedelta(minutes=3) - def __init__(self, token=None, disable_async = True, stale_duration=timedelta(minutes=_default_stale_duration)): - self._lock = threading.Lock() # to guard _token + def __init__(self, + token: Token = None, + disable_async: bool = True, + stale_duration: timedelta = _DEFAULT_STALE_DURATION): + self._lock = threading.Lock() self._token = token self._stale_duration = stale_duration self._disable_async = disable_async self._is_refreshing = False self._refresh_err = False - def token(self, blocking=False) -> Token: + def token(self) -> Token: + """Returns a valid token, blocking if async refresh is disabled.""" if self._disable_async: return self._blocking_token() return self._async_token() def _async_token(self) -> Token: - self._lock.acquire() - token_state = self._token_state() - token = self._token - self._lock.release() - match token_state: - case _TokenState.FRESH: - return token - case _TokenState.STALE: - self._trigger_async_refresh() - return token - case _: #Expired - return self._blocking_token() + """ + Returns a token. + If the token is stale, triggers an asynchronous refresh. + If the token is expired, refreshes it synchronously, blocking until the refresh is complete. + """ + with self._lock: + token_state = self._token_state() + token = self._token + if token_state == _TokenState.FRESH: + return token + if token_state == _TokenState.STALE: + self._trigger_async_refresh() + return token + return self._blocking_token() def _token_state(self) -> _TokenState: - """ - Returns the state of the token. - """ - # Invalid tokens are considered expired. + """Returns the current state of the token.""" if not self._token or not self._token.valid: return _TokenState.EXPIRED - # Tokens without an expiry are considered always. if not self._token.expiry: return _TokenState.FRESH + lifespan = self._token.expiry - datetime.now() if lifespan < timedelta(seconds=0): return _TokenState.EXPIRED @@ -257,13 +257,10 @@ def _token_state(self) -> _TokenState: return _TokenState.FRESH def _blocking_token(self) -> Token: - - # The lock is kept for the entire operation to ensure that only one - # refresh operation is running at a time. + """Returns a token, blocking if necessary to refresh it.""" with self._lock: # This is important to recover from potential previous failed attempts - # to refresh the token asynchronously, see declaration of refresh_err for - # more information. + # to refresh the token asynchronously. self._refresh_err = False self._is_refreshing = False @@ -273,13 +270,13 @@ def _blocking_token(self) -> Token: if self._token_state() != _TokenState.EXPIRED: return self._token - # Refresh the token self._token = self.refresh() return self._token - def _trigger_async_refresh(self): - # Note: this is not thread safe. + """Starts an asynchronous refresh if none is in progress.""" + + # Note: _refresh_internal function is not thread safe. # Only call it inside the lock. def _refresh_internal(): try: @@ -288,12 +285,11 @@ def _refresh_internal(): self._refresh_err = True finally: self._is_refreshing = False - # The lock is kept for the entire operation to ensure that only one - # refresh operation is running at a time. + with self._lock: if not self._is_refreshing and not self._refresh_err: self._is_refreshing = True - self._executor.submit(_refresh_internal) + self._EXECUTOR.submit(_refresh_internal) @abstractmethod def refresh(self) -> Token: diff --git a/tests/test_refreshable.py b/tests/test_refreshable.py index be28d9531..b011c754b 100644 --- a/tests/test_refreshable.py +++ b/tests/test_refreshable.py @@ -1,17 +1,18 @@ import time +from datetime import datetime, timedelta from time import sleep from typing import Callable -import pytest - -from datetime import datetime, timedelta - from databricks.sdk.oauth import Refreshable, Token class _MockRefreshable(Refreshable): - def __init__(self, disable_async, token=None, stale_duration=timedelta(seconds=60), refresh_effect: Callable[[], Token]=None): + def __init__(self, + disable_async, + token=None, + stale_duration=timedelta(seconds=60), + refresh_effect: Callable[[], Token] = None): super().__init__(token, disable_async, stale_duration) self._refresh_effect = refresh_effect self._refresh_count = 0 @@ -22,16 +23,21 @@ def refresh(self) -> Token: self._refresh_count += 1 return self._token + def fail() -> Token: raise Exception("Failed to refresh token") -def static_token(token: Token, wait: int=0) -> Callable[[], Token]: + +def static_token(token: Token, wait: int = 0) -> Callable[[], Token]: + def f() -> Token: time.sleep(wait) return token + return f -def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[],None]): + +def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[], None]): """ Create a refresh function that blocks until unblock is called. @@ -43,103 +49,87 @@ def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[],None]): """ blocking = True + def refresh(): while blocking: sleep(0.1) return token + def unblock(): nonlocal blocking blocking = False + return refresh, unblock def test_disable_async_stale_does_not_refresh(): - stale_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=50), - ) + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) r = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail) result = r.token() assert r._refresh_count == 0 assert result == stale_token + def test_disable_async_no_token_does_refresh(): - token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=50), - ) + token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) r = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token)) result = r.token() assert r._refresh_count == 1 assert result == token + def test_disable_async_no_expiration_does_not_refresh(): - non_expiring_token = Token( - access_token="access_token", - ) + non_expiring_token = Token(access_token="access_token", ) r = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail) result = r.token() assert r._refresh_count == 0 assert result == non_expiring_token + def test_disable_async_fresh_does_not_refresh(): # Create a token that is already stale. If async is disabled, the token should not be refreshed. - token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) r = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail) result = r.token() assert r._refresh_count == 0 assert result == token + def test_disable_async_expired_does_refresh(): - expired_token = Token( - access_token="access_token", - expiry=datetime.now() - timedelta(seconds=300), - ) - new_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) # Add one second to the refresh time to ensure that the call is blocking. # If the call is not blocking, the wait time will ensure that the # old token is returned. - r = _MockRefreshable(token=expired_token, disable_async=True, refresh_effect=static_token(new_token, wait=1)) + r = _MockRefreshable(token=expired_token, + disable_async=True, + refresh_effect=static_token(new_token, wait=1)) result = r.token() assert r._refresh_count == 1 assert result == new_token + def test_expired_does_refresh(): - expired_token = Token( - access_token="access_token", - expiry=datetime.now() - timedelta(seconds=300), - ) - new_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) # Add one second to the refresh time to ensure that the call is blocking. # If the call is not blocking, the wait time will ensure that the # old token is returned. - r = _MockRefreshable(token=expired_token, disable_async=False, refresh_effect=static_token(new_token, wait=1)) + r = _MockRefreshable(token=expired_token, + disable_async=False, + refresh_effect=static_token(new_token, wait=1)) result = r.token() assert r._refresh_count == 1 assert result == new_token + def test_stale_does_refresh_async(): - stale_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=50), - ) - new_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) # Add one second to the refresh to avoid race conditions. # Without it, the new token may be returned in some cases. refresh, unblock = blocking_refresh(new_token) - r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) result = r.token() # NOTE: Do not check for refresh count here, since the assert result == stale_token @@ -156,10 +146,7 @@ def test_stale_does_refresh_async(): def test_no_token_does_refresh(): - new_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) # Add one second to the refresh time to ensure that the call is blocking. # If the call is not blocking, the wait time will ensure that the # token is not returned. @@ -168,27 +155,20 @@ def test_no_token_does_refresh(): assert r._refresh_count == 1 assert result == new_token + def test_fresh_does_not_refresh(): - fresh_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + fresh_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) r = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail) result = r.token() assert r._refresh_count == 0 assert result == fresh_token + def test_multiple_calls_dont_start_many_threads(): - stale_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=59), - ) - new_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=300), - ) + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) refresh, unblock = blocking_refresh(new_token) - r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) # Call twice. The second call should not start a new thread. result = r.token() assert result == stale_token @@ -202,16 +182,11 @@ def test_multiple_calls_dont_start_many_threads(): assert r._refresh_count == 1 assert result == new_token + def test_async_failure_disables_async(): - stale_token = Token( - access_token="access_token", - expiry=datetime.now() + timedelta(seconds=59), - ) - new_token = Token( - access_token="new_token", - expiry=datetime.now() + timedelta(seconds=300), - ) - r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail) + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), ) + new_token = Token(access_token="new_token", expiry=datetime.now() + timedelta(seconds=300), ) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail) # The call should fail and disable async refresh, # but the exception will be catch inside the tread. result = r.token() @@ -230,10 +205,7 @@ def test_async_failure_disables_async(): assert r._refresh_count == 0 # Inject an expired token. - expired_token = Token( - access_token="access_token", - expiry=datetime.now() - timedelta(seconds=300), - ) + expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) r._token = expired_token # This should be blocking and return the new token. From 21c999fb271ad230105e79a694844a1e8f4bd6e8 Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Wed, 19 Feb 2025 11:27:22 +0100 Subject: [PATCH 3/4] Lazy load the executor --- databricks/sdk/oauth.py | 67 ++++++++++++++++++++++++++++----------- tests/test_refreshable.py | 2 +- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e16e006af..58d6ced75 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -205,17 +205,31 @@ class _TokenState(Enum): class Refreshable(TokenSource): """A token source that supports refreshing expired tokens.""" - _EXECUTOR = ThreadPoolExecutor(max_workers=10) + _EXECUTOR = None + _EXECUTOR_LOCK = threading.Lock() _DEFAULT_STALE_DURATION = timedelta(minutes=3) + @classmethod + def _get_executor(cls): + """Lazy initialization of the ThreadPoolExecutor.""" + if cls._EXECUTOR is None: + with cls._EXECUTOR_LOCK: + if cls._EXECUTOR is None: + # This thread pool has multiple workers because it is shared by all instances of Refreshable. + cls._EXECUTOR = ThreadPoolExecutor(max_workers=10) + return cls._EXECUTOR + def __init__(self, token: Token = None, disable_async: bool = True, stale_duration: timedelta = _DEFAULT_STALE_DURATION): - self._lock = threading.Lock() - self._token = token + # Config properties self._stale_duration = stale_duration self._disable_async = disable_async + # Lock + self._lock = threading.Lock() + # Non Thread safe properties. Protected by the lock above. + self._token = token self._is_refreshing = False self._refresh_err = False @@ -232,33 +246,37 @@ def _async_token(self) -> Token: If the token is expired, refreshes it synchronously, blocking until the refresh is complete. """ with self._lock: - token_state = self._token_state() + state = Refreshable._token_state(self._token, self._stale_duration) token = self._token - if token_state == _TokenState.FRESH: + if state == _TokenState.FRESH: return token - if token_state == _TokenState.STALE: + if state == _TokenState.STALE: self._trigger_async_refresh() return token return self._blocking_token() - def _token_state(self) -> _TokenState: + # This is a class method and we pass the token to avoid + # concurrency issues and deadlocks. + @classmethod + def _token_state(cls, token: Token, stale_duration: timedelta) -> _TokenState: """Returns the current state of the token.""" - if not self._token or not self._token.valid: + if not token or not token.valid: return _TokenState.EXPIRED - if not self._token.expiry: + if not token.expiry: return _TokenState.FRESH - lifespan = self._token.expiry - datetime.now() + lifespan = token.expiry - datetime.now() if lifespan < timedelta(seconds=0): return _TokenState.EXPIRED - if lifespan < self._stale_duration: + if lifespan < stale_duration: return _TokenState.STALE return _TokenState.FRESH def _blocking_token(self) -> Token: """Returns a token, blocking if necessary to refresh it.""" with self._lock: + state = Refreshable._token_state(self._token, self._stale_duration) # This is important to recover from potential previous failed attempts # to refresh the token asynchronously. self._refresh_err = False @@ -267,7 +285,7 @@ def _blocking_token(self) -> Token: # It's possible that the token got refreshed (either by a _blocking_refresh or # an _async_refresh call) while this particular call was waiting to acquire # the lock. This check avoids refreshing the token again in such cases. - if self._token_state() != _TokenState.EXPIRED: + if state != _TokenState.EXPIRED: return self._token self._token = self.refresh() @@ -276,20 +294,31 @@ def _blocking_token(self) -> Token: def _trigger_async_refresh(self): """Starts an asynchronous refresh if none is in progress.""" - # Note: _refresh_internal function is not thread safe. - # Only call it inside the lock. def _refresh_internal(): + new_token: Token = None try: - self._token = self.refresh() - except Exception: - self._refresh_err = True - finally: + new_token = self.refresh() + except Exception as e: + # This happens on a thread, so we don't want to propagate the error. + # Instead, if there is no new_token for any reason, we will disable async refresh below + # But we will do it inside the lock. + logger.warning(f'Tried to refresh token asynchronously, but failed: {e}') + + with self._lock: + if new_token is not None: + self._token = new_token + else: + self._refresh_err = True self._is_refreshing = False with self._lock: + state = Refreshable._token_state(self._token, self._stale_duration) + # The token may have been refreshed by another thread. + if state == _TokenState.FRESH: + return if not self._is_refreshing and not self._refresh_err: self._is_refreshing = True - self._EXECUTOR.submit(_refresh_internal) + Refreshable._get_executor().submit(_refresh_internal) @abstractmethod def refresh(self) -> Token: diff --git a/tests/test_refreshable.py b/tests/test_refreshable.py index b011c754b..7265026e8 100644 --- a/tests/test_refreshable.py +++ b/tests/test_refreshable.py @@ -25,7 +25,7 @@ def refresh(self) -> Token: def fail() -> Token: - raise Exception("Failed to refresh token") + raise Exception("Simulated token refresh failure") def static_token(token: Token, wait: int = 0) -> Callable[[], Token]: From d3ffe7968de23f88dcb07c774eabad12b1baa05e Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Fri, 21 Feb 2025 14:54:03 +0100 Subject: [PATCH 4/4] Mega block --- databricks/sdk/oauth.py | 72 +++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 58d6ced75..c9a9d15c6 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -228,16 +228,19 @@ def __init__(self, self._disable_async = disable_async # Lock self._lock = threading.Lock() - # Non Thread safe properties. Protected by the lock above. + # Non Thread safe properties. They should be accessed only when protected by the lock above. self._token = token self._is_refreshing = False self._refresh_err = False + # This is the main entry point for the Token. Do not access the token + # using any of the internal functions. def token(self) -> Token: """Returns a valid token, blocking if async refresh is disabled.""" - if self._disable_async: - return self._blocking_token() - return self._async_token() + with self._lock: + if self._disable_async: + return self._blocking_token() + return self._async_token() def _async_token(self) -> Token: """ @@ -245,9 +248,8 @@ def _async_token(self) -> Token: If the token is stale, triggers an asynchronous refresh. If the token is expired, refreshes it synchronously, blocking until the refresh is complete. """ - with self._lock: - state = Refreshable._token_state(self._token, self._stale_duration) - token = self._token + state = self._token_state() + token = self._token if state == _TokenState.FRESH: return token @@ -256,41 +258,37 @@ def _async_token(self) -> Token: return token return self._blocking_token() - # This is a class method and we pass the token to avoid - # concurrency issues and deadlocks. - @classmethod - def _token_state(cls, token: Token, stale_duration: timedelta) -> _TokenState: + def _token_state(self) -> _TokenState: """Returns the current state of the token.""" - if not token or not token.valid: + if not self._token or not self._token.valid: return _TokenState.EXPIRED - if not token.expiry: + if not self._token.expiry: return _TokenState.FRESH - lifespan = token.expiry - datetime.now() + lifespan = self._token.expiry - datetime.now() if lifespan < timedelta(seconds=0): return _TokenState.EXPIRED - if lifespan < stale_duration: + if lifespan < self._stale_duration: return _TokenState.STALE return _TokenState.FRESH def _blocking_token(self) -> Token: """Returns a token, blocking if necessary to refresh it.""" - with self._lock: - state = Refreshable._token_state(self._token, self._stale_duration) - # This is important to recover from potential previous failed attempts - # to refresh the token asynchronously. - self._refresh_err = False - self._is_refreshing = False - - # It's possible that the token got refreshed (either by a _blocking_refresh or - # an _async_refresh call) while this particular call was waiting to acquire - # the lock. This check avoids refreshing the token again in such cases. - if state != _TokenState.EXPIRED: - return self._token - - self._token = self.refresh() + state = self._token_state() + # This is important to recover from potential previous failed attempts + # to refresh the token asynchronously. + self._refresh_err = False + self._is_refreshing = False + + # It's possible that the token got refreshed (either by a _blocking_refresh or + # an _async_refresh call) while this particular call was waiting to acquire + # the lock. This check avoids refreshing the token again in such cases. + if state != _TokenState.EXPIRED: return self._token + self._token = self.refresh() + return self._token + def _trigger_async_refresh(self): """Starts an asynchronous refresh if none is in progress.""" @@ -311,14 +309,12 @@ def _refresh_internal(): self._refresh_err = True self._is_refreshing = False - with self._lock: - state = Refreshable._token_state(self._token, self._stale_duration) - # The token may have been refreshed by another thread. - if state == _TokenState.FRESH: - return - if not self._is_refreshing and not self._refresh_err: - self._is_refreshing = True - Refreshable._get_executor().submit(_refresh_internal) + # The token may have been refreshed by another thread. + if self._token_state() == _TokenState.FRESH: + return + if not self._is_refreshing and not self._refresh_err: + self._is_refreshing = True + Refreshable._get_executor().submit(_refresh_internal) @abstractmethod def refresh(self) -> Token: @@ -412,7 +408,7 @@ def __init__(self, super().__init__(token) def as_dict(self) -> dict: - return {'token': self._token.as_dict()} + return {'token': self.token().as_dict()} @staticmethod def from_dict(raw: dict,