-
Notifications
You must be signed in to change notification settings - Fork 178
[Internal] Implement async token refresh #893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,8 +9,10 @@ | |
| import urllib.parse | ||
| import webbrowser | ||
| from abc import abstractmethod | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from dataclasses import dataclass | ||
| from datetime import datetime, timedelta | ||
| from enum import Enum | ||
| from http.server import BaseHTTPRequestHandler, HTTPServer | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
|
|
@@ -187,21 +189,136 @@ def retrieve_token(client_id, | |
| raise NotImplementedError(f"Not supported yet: {e}") | ||
|
|
||
|
|
||
| class _TokenState(Enum): | ||
| """ | ||
| Represents the state of a token. Each token can be in one of | ||
| the following three states: | ||
| - FRESH: The token is valid. | ||
| - STALE: The token is valid but will expire soon. | ||
| - EXPIRED: The token has expired and cannot be used. | ||
| """ | ||
| FRESH = 1 # The token is valid. | ||
| STALE = 2 # The token is valid but will expire soon. | ||
| EXPIRED = 3 # The token has expired and cannot be used. | ||
|
|
||
|
|
||
| class Refreshable(TokenSource): | ||
| """A token source that supports refreshing expired tokens.""" | ||
|
|
||
| _EXECUTOR = None | ||
| _EXECUTOR_LOCK = threading.Lock() | ||
| _DEFAULT_STALE_DURATION = timedelta(minutes=3) | ||
|
|
||
| @classmethod | ||
| def _get_executor(cls): | ||
| """Lazy initialization of the ThreadPoolExecutor.""" | ||
| if cls._EXECUTOR is None: | ||
| with cls._EXECUTOR_LOCK: | ||
| if cls._EXECUTOR is None: | ||
| # This thread pool has multiple workers because it is shared by all instances of Refreshable. | ||
| cls._EXECUTOR = ThreadPoolExecutor(max_workers=10) | ||
| return cls._EXECUTOR | ||
|
|
||
| def __init__(self, token=None): | ||
| self._lock = threading.Lock() # to guard _token | ||
| def __init__(self, | ||
| token: Token = None, | ||
| disable_async: bool = True, | ||
| stale_duration: timedelta = _DEFAULT_STALE_DURATION): | ||
| # Config properties | ||
| self._stale_duration = stale_duration | ||
| self._disable_async = disable_async | ||
| # Lock | ||
| self._lock = threading.Lock() | ||
| # Non Thread safe properties. Protected by the lock above. | ||
| self._token = token | ||
| self._is_refreshing = False | ||
| self._refresh_err = False | ||
hectorcast-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def token(self) -> Token: | ||
| self._lock.acquire() | ||
| try: | ||
| if self._token and self._token.valid: | ||
| """Returns a valid token, blocking if async refresh is disabled.""" | ||
| if self._disable_async: | ||
| return self._blocking_token() | ||
| return self._async_token() | ||
|
|
||
| def _async_token(self) -> Token: | ||
| """ | ||
| Returns a token. | ||
| If the token is stale, triggers an asynchronous refresh. | ||
| If the token is expired, refreshes it synchronously, blocking until the refresh is complete. | ||
| """ | ||
| with self._lock: | ||
| state = Refreshable._token_state(self._token, self._stale_duration) | ||
| token = self._token | ||
|
|
||
| if state == _TokenState.FRESH: | ||
| return token | ||
| if state == _TokenState.STALE: | ||
| self._trigger_async_refresh() | ||
| return token | ||
| return self._blocking_token() | ||
|
|
||
| # This is a class method and we pass the token to avoid | ||
| # concurrency issues and deadlocks. | ||
| @classmethod | ||
| def _token_state(cls, token: Token, stale_duration: timedelta) -> _TokenState: | ||
| """Returns the current state of the token.""" | ||
| if not token or not token.valid: | ||
| return _TokenState.EXPIRED | ||
| if not token.expiry: | ||
| return _TokenState.FRESH | ||
|
|
||
| lifespan = token.expiry - datetime.now() | ||
| if lifespan < timedelta(seconds=0): | ||
| return _TokenState.EXPIRED | ||
| if lifespan < stale_duration: | ||
| return _TokenState.STALE | ||
| return _TokenState.FRESH | ||
|
|
||
| def _blocking_token(self) -> Token: | ||
| """Returns a token, blocking if necessary to refresh it.""" | ||
| with self._lock: | ||
| state = Refreshable._token_state(self._token, self._stale_duration) | ||
| # This is important to recover from potential previous failed attempts | ||
| # to refresh the token asynchronously. | ||
| self._refresh_err = False | ||
| self._is_refreshing = False | ||
|
|
||
| # It's possible that the token got refreshed (either by a _blocking_refresh or | ||
| # an _async_refresh call) while this particular call was waiting to acquire | ||
| # the lock. This check avoids refreshing the token again in such cases. | ||
| if state != _TokenState.EXPIRED: | ||
| return self._token | ||
|
|
||
| self._token = self.refresh() | ||
| return self._token | ||
| finally: | ||
| self._lock.release() | ||
|
|
||
| def _trigger_async_refresh(self): | ||
| """Starts an asynchronous refresh if none is in progress.""" | ||
|
|
||
| def _refresh_internal(): | ||
| new_token: Token = None | ||
| try: | ||
| new_token = self.refresh() | ||
| except Exception as e: | ||
| # This happens on a thread, so we don't want to propagate the error. | ||
| # Instead, if there is no new_token for any reason, we will disable async refresh below | ||
| # But we will do it inside the lock. | ||
| logger.warning(f'Tried to refresh token asynchronously, but failed: {e}') | ||
|
|
||
| with self._lock: | ||
| if new_token is not None: | ||
| self._token = new_token | ||
| else: | ||
| self._refresh_err = True | ||
| self._is_refreshing = False | ||
|
|
||
| with self._lock: | ||
| state = Refreshable._token_state(self._token, self._stale_duration) | ||
| # The token may have been refreshed by another thread. | ||
| if state == _TokenState.FRESH: | ||
| return | ||
renaudhartert-db marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not self._is_refreshing and not self._refresh_err: | ||
|
||
| self._is_refreshing = True | ||
| Refreshable._get_executor().submit(_refresh_internal) | ||
|
|
||
| @abstractmethod | ||
| def refresh(self) -> Token: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,216 @@ | ||
| import time | ||
| from datetime import datetime, timedelta | ||
| from time import sleep | ||
| from typing import Callable | ||
|
|
||
| from databricks.sdk.oauth import Refreshable, Token | ||
|
|
||
|
|
||
| class _MockRefreshable(Refreshable): | ||
|
|
||
| def __init__(self, | ||
| disable_async, | ||
| token=None, | ||
| stale_duration=timedelta(seconds=60), | ||
| refresh_effect: Callable[[], Token] = None): | ||
| super().__init__(token, disable_async, stale_duration) | ||
| self._refresh_effect = refresh_effect | ||
| self._refresh_count = 0 | ||
|
|
||
| def refresh(self) -> Token: | ||
| if self._refresh_effect: | ||
| self._token = self._refresh_effect() | ||
| self._refresh_count += 1 | ||
| return self._token | ||
|
|
||
|
|
||
| def fail() -> Token: | ||
| raise Exception("Simulated token refresh failure") | ||
|
|
||
|
|
||
| def static_token(token: Token, wait: int = 0) -> Callable[[], Token]: | ||
|
|
||
| def f() -> Token: | ||
| time.sleep(wait) | ||
| return token | ||
|
|
||
| return f | ||
|
|
||
|
|
||
| def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[], None]): | ||
| """ | ||
| Create a refresh function that blocks until unblock is called. | ||
|
|
||
| Param: | ||
| token: the token that will be returned | ||
|
|
||
| Returns: | ||
| A tuple containing the refresh function and the unblock function. | ||
|
|
||
| """ | ||
| blocking = True | ||
|
|
||
| def refresh(): | ||
| while blocking: | ||
| sleep(0.1) | ||
| return token | ||
|
|
||
| def unblock(): | ||
| nonlocal blocking | ||
| blocking = False | ||
|
|
||
| return refresh, unblock | ||
|
|
||
|
|
||
| def test_disable_async_stale_does_not_refresh(): | ||
| stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) | ||
| r = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail) | ||
| result = r.token() | ||
| assert r._refresh_count == 0 | ||
| assert result == stale_token | ||
|
|
||
|
|
||
| def test_disable_async_no_token_does_refresh(): | ||
| token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) | ||
| r = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token)) | ||
| result = r.token() | ||
| assert r._refresh_count == 1 | ||
| assert result == token | ||
|
|
||
|
|
||
| def test_disable_async_no_expiration_does_not_refresh(): | ||
| non_expiring_token = Token(access_token="access_token", ) | ||
| r = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail) | ||
| result = r.token() | ||
| assert r._refresh_count == 0 | ||
| assert result == non_expiring_token | ||
|
|
||
|
|
||
| def test_disable_async_fresh_does_not_refresh(): | ||
| # Create a token that is already stale. If async is disabled, the token should not be refreshed. | ||
| token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| r = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail) | ||
| result = r.token() | ||
| assert r._refresh_count == 0 | ||
| assert result == token | ||
|
|
||
|
|
||
| def test_disable_async_expired_does_refresh(): | ||
| expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) | ||
| new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| # Add one second to the refresh time to ensure that the call is blocking. | ||
| # If the call is not blocking, the wait time will ensure that the | ||
| # old token is returned. | ||
| r = _MockRefreshable(token=expired_token, | ||
| disable_async=True, | ||
| refresh_effect=static_token(new_token, wait=1)) | ||
| result = r.token() | ||
| assert r._refresh_count == 1 | ||
| assert result == new_token | ||
|
|
||
|
|
||
| def test_expired_does_refresh(): | ||
| expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) | ||
| new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| # Add one second to the refresh time to ensure that the call is blocking. | ||
| # If the call is not blocking, the wait time will ensure that the | ||
| # old token is returned. | ||
| r = _MockRefreshable(token=expired_token, | ||
| disable_async=False, | ||
| refresh_effect=static_token(new_token, wait=1)) | ||
| result = r.token() | ||
| assert r._refresh_count == 1 | ||
| assert result == new_token | ||
|
|
||
|
|
||
| def test_stale_does_refresh_async(): | ||
| stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) | ||
| new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| # Add one second to the refresh to avoid race conditions. | ||
| # Without it, the new token may be returned in some cases. | ||
| refresh, unblock = blocking_refresh(new_token) | ||
| r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) | ||
| result = r.token() | ||
| # NOTE: Do not check for refresh count here, since the | ||
| assert result == stale_token | ||
| assert r._refresh_count == 0 | ||
| # Unblock the refresh and wait | ||
| unblock() | ||
| time.sleep(2) | ||
| # Call again and check that you get the new token | ||
| result = r.token() | ||
| assert result == new_token | ||
| # Ensure that all calls have completed | ||
| time.sleep(0.1) | ||
| assert r._refresh_count == 1 | ||
|
|
||
|
|
||
| def test_no_token_does_refresh(): | ||
| new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| # Add one second to the refresh time to ensure that the call is blocking. | ||
| # If the call is not blocking, the wait time will ensure that the | ||
| # token is not returned. | ||
| r = _MockRefreshable(token=None, disable_async=False, refresh_effect=static_token(new_token, wait=1)) | ||
| result = r.token() | ||
| assert r._refresh_count == 1 | ||
| assert result == new_token | ||
|
|
||
|
|
||
| def test_fresh_does_not_refresh(): | ||
| fresh_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| r = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail) | ||
| result = r.token() | ||
| assert r._refresh_count == 0 | ||
| assert result == fresh_token | ||
|
|
||
|
|
||
| def test_multiple_calls_dont_start_many_threads(): | ||
| stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), ) | ||
| new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| refresh, unblock = blocking_refresh(new_token) | ||
| r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) | ||
| # Call twice. The second call should not start a new thread. | ||
| result = r.token() | ||
| assert result == stale_token | ||
| result = r.token() | ||
| assert result == stale_token | ||
| unblock() | ||
| # Wait for the refresh to complete | ||
| time.sleep(1) | ||
| result = r.token() | ||
| # Check that only one refresh was called | ||
| assert r._refresh_count == 1 | ||
| assert result == new_token | ||
|
|
||
|
|
||
| def test_async_failure_disables_async(): | ||
| stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), ) | ||
| new_token = Token(access_token="new_token", expiry=datetime.now() + timedelta(seconds=300), ) | ||
| r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail) | ||
| # The call should fail and disable async refresh, | ||
| # but the exception will be catch inside the tread. | ||
| result = r.token() | ||
| assert result == stale_token | ||
| # Give time to the async refresh to fail | ||
| time.sleep(1) | ||
| assert r._refresh_err | ||
| # Now, the refresh should be blocking. | ||
| # Blocking refresh only happens for expired, not stale. | ||
| # Therefore, the next call should return the stale token. | ||
| r._refresh_effect = static_token(new_token, wait=1) | ||
| result = r.token() | ||
| assert result == stale_token | ||
| # Wait to be sure no async thread was started | ||
| time.sleep(1) | ||
| assert r._refresh_count == 0 | ||
|
|
||
| # Inject an expired token. | ||
| expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) | ||
| r._token = expired_token | ||
|
|
||
| # This should be blocking and return the new token. | ||
| result = r.token() | ||
| assert r._refresh_count == 1 | ||
| assert result == new_token | ||
| # The refresh error should be cleared. | ||
| assert not r._refresh_err |
Uh oh!
There was an error while loading. Please reload this page.