Skip to content

Commit e550ca1

Browse files
[Internal] Implement async token refresh (#893)
## What changes are proposed in this pull request? This PR is a step towards enabling asynchronous refreshes of data plane tokens. This PR updates the existing `Refreshable` abstract token class to support async token refresh. Note: async refreshes are disabled at the moment and will be enabled in a follow-up PR. ## How is this tested? Added unit tests. ## Changelog The changelog entry will be added when the feature is enabled. NO_CHANGELOG=true
1 parent 3d3752a commit e550ca1

File tree

2 files changed

+339
-10
lines changed

2 files changed

+339
-10
lines changed

databricks/sdk/oauth.py

Lines changed: 123 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import urllib.parse
1010
import webbrowser
1111
from abc import abstractmethod
12+
from concurrent.futures import ThreadPoolExecutor
1213
from dataclasses import dataclass
1314
from datetime import datetime, timedelta
15+
from enum import Enum
1416
from http.server import BaseHTTPRequestHandler, HTTPServer
1517
from typing import Any, Dict, List, Optional
1618

@@ -187,21 +189,132 @@ def retrieve_token(client_id,
187189
raise NotImplementedError(f"Not supported yet: {e}")
188190

189191

192+
class _TokenState(Enum):
193+
"""
194+
Represents the state of a token. Each token can be in one of
195+
the following three states:
196+
- FRESH: The token is valid.
197+
- STALE: The token is valid but will expire soon.
198+
- EXPIRED: The token has expired and cannot be used.
199+
"""
200+
FRESH = 1 # The token is valid.
201+
STALE = 2 # The token is valid but will expire soon.
202+
EXPIRED = 3 # The token has expired and cannot be used.
203+
204+
190205
class Refreshable(TokenSource):
206+
"""A token source that supports refreshing expired tokens."""
207+
208+
_EXECUTOR = None
209+
_EXECUTOR_LOCK = threading.Lock()
210+
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
211+
212+
@classmethod
213+
def _get_executor(cls):
214+
"""Lazy initialization of the ThreadPoolExecutor."""
215+
if cls._EXECUTOR is None:
216+
with cls._EXECUTOR_LOCK:
217+
if cls._EXECUTOR is None:
218+
# This thread pool has multiple workers because it is shared by all instances of Refreshable.
219+
cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
220+
return cls._EXECUTOR
191221

192-
def __init__(self, token=None):
193-
self._lock = threading.Lock() # to guard _token
222+
def __init__(self,
             token: Optional[Token] = None,
             disable_async: bool = True,
             stale_duration: timedelta = _DEFAULT_STALE_DURATION):
    """Set up the token source.

    :param token: initial token; with None, the first call to ``token()`` refreshes.
    :param disable_async: when True (the default), ``token()`` only ever
        refreshes synchronously.
    :param stale_duration: remaining lifetime below which a still-valid token
        is treated as stale.
    """
    # Config properties
    self._stale_duration = stale_duration
    self._disable_async = disable_async
    # Lock
    self._lock = threading.Lock()
    # Not thread-safe: the fields below must only be accessed while holding the lock above.
    self._token = token
    # True from the moment a background refresh is scheduled until it finishes.
    self._is_refreshing = False
    # Set when a background refresh fails; cleared by the next blocking call.
    self._refresh_err = False
195235

236+
# This is the main entry point for the Token. Do not access the token
237+
# using any of the internal functions.
196238
def token(self) -> Token:
    """Return a valid token.

    The whole lookup runs under the instance lock. When async refresh is
    disabled the blocking path is used; otherwise the async path decides
    whether to refresh in the background or to block.
    """
    with self._lock:
        fetch = self._blocking_token if self._disable_async else self._async_token
        return fetch()
244+
245+
def _async_token(self) -> Token:
    """Return a token, refreshing in the background when possible.

    - FRESH: the current token is returned as is.
    - STALE: an asynchronous refresh is triggered and the current (still
      valid) token is returned immediately.
    - Otherwise (expired): the token is refreshed synchronously, blocking
      until the refresh completes.
    """
    current = self._token
    state = self._token_state()

    if state not in (_TokenState.FRESH, _TokenState.STALE):
        return self._blocking_token()
    if state == _TokenState.STALE:
        self._trigger_async_refresh()
    return current
260+
261+
def _token_state(self) -> _TokenState:
    """Classify the current token as FRESH, STALE or EXPIRED."""
    current = self._token
    if not current or not current.valid:
        return _TokenState.EXPIRED
    if not current.expiry:
        # No expiry recorded: there is nothing to measure staleness against.
        return _TokenState.FRESH

    remaining = current.expiry - datetime.now()
    if remaining < timedelta(seconds=0):
        return _TokenState.EXPIRED
    if remaining < self._stale_duration:
        return _TokenState.STALE
    return _TokenState.FRESH
274+
275+
def _blocking_token(self) -> Token:
    """Return a token, refreshing it synchronously if it has expired."""
    expired = self._token_state() == _TokenState.EXPIRED

    # Clear the async bookkeeping so that a previously failed asynchronous
    # refresh does not keep async refreshes disabled from here on.
    self._refresh_err = False
    self._is_refreshing = False

    # The token may already have been refreshed (by another blocking call or
    # by a background refresh) while this call was waiting for the lock; in
    # that case it is no longer expired and is returned without refreshing.
    if expired:
        self._token = self.refresh()
    return self._token
291+
292+
def _trigger_async_refresh(self):
    """Start a background refresh unless one is in progress or one has failed.

    Reads and writes the lock-guarded fields, so callers are expected to hold
    ``self._lock``. Failures never propagate to the caller: they set
    ``_refresh_err``, which stops further background attempts until a
    blocking call clears it.
    """

    def _refresh_internal():
        new_token: Optional[Token] = None
        try:
            new_token = self.refresh()
        except Exception as e:
            # This runs on a worker thread, so the error is not propagated.
            # With no new_token, async refresh is disabled below, inside the lock.
            logger.warning(f'Tried to refresh token asynchronously, but failed: {e}')

        with self._lock:
            if new_token is not None:
                self._token = new_token
            else:
                self._refresh_err = True
            self._is_refreshing = False

    # The token may have been refreshed by another thread.
    if self._token_state() == _TokenState.FRESH:
        return
    if self._is_refreshing or self._refresh_err:
        return

    self._is_refreshing = True
    try:
        Refreshable._get_executor().submit(_refresh_internal)
    except RuntimeError as e:
        # submit() raises RuntimeError once the executor has been shut down.
        # Nothing was scheduled, so undo the in-progress marker and treat this
        # like any other failed async refresh instead of failing a caller
        # that still holds a valid (stale) token.
        self._is_refreshing = False
        self._refresh_err = True
        logger.warning(f'Could not schedule asynchronous token refresh: {e}')
205318

206319
@abstractmethod
207320
def refresh(self) -> Token:
@@ -295,7 +408,7 @@ def __init__(self,
295408
super().__init__(token)
296409

297410
def as_dict(self) -> dict:
    """Return the token obtained through ``token()``, serialized under the 'token' key."""
    current = self.token()
    return {'token': current.as_dict()}
299412

300413
@staticmethod
301414
def from_dict(raw: dict,

tests/test_refreshable.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import time
2+
from datetime import datetime, timedelta
3+
from time import sleep
4+
from typing import Callable
5+
6+
from databricks.sdk.oauth import Refreshable, Token
7+
8+
9+
class _MockRefreshable(Refreshable):
    """Refreshable test double whose refresh() runs a pluggable effect and counts completed calls."""

    def __init__(self,
                 disable_async,
                 token=None,
                 stale_duration=timedelta(seconds=60),
                 refresh_effect: Callable[[], Token] = None):
        super().__init__(token, disable_async, stale_duration)
        self._refresh_count = 0
        self._refresh_effect = refresh_effect

    def refresh(self) -> Token:
        """Run the configured effect (if any), bump the counter and return the token."""
        effect = self._refresh_effect
        if effect:
            self._token = effect()
        # Not reached when the effect raises, so failed refreshes are not counted.
        self._refresh_count += 1
        return self._token
25+
26+
27+
def fail() -> Token:
    """Refresh effect that always raises, simulating a failed token refresh."""
    raise Exception("Simulated token refresh failure")
29+
30+
31+
def static_token(token: Token, wait: int = 0) -> Callable[[], Token]:
    """Build a refresh effect that sleeps for `wait` seconds and then returns `token`."""

    def _delayed() -> Token:
        time.sleep(wait)
        return token

    return _delayed
38+
39+
40+
def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[], None]):
    """
    Create a refresh function that blocks until unblock is called.

    Args:
        token: the token that the refresh function will return.

    Returns:
        A tuple containing the refresh function and the unblock function.
    """
    # Imported locally because this module has no file-level threading import.
    import threading

    # An Event wakes the waiting refresh as soon as unblock() is called,
    # instead of polling a flag in a sleep loop.
    released = threading.Event()

    def refresh():
        released.wait()
        return token

    def unblock():
        released.set()

    return refresh, unblock
63+
64+
65+
def test_disable_async_stale_does_not_refresh():
    """With async disabled, a stale but valid token is returned without refreshing."""
    expiry = datetime.now() + timedelta(seconds=50)
    stale_token = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail)
    result = source.token()
    assert source._refresh_count == 0
    assert result == stale_token
71+
72+
73+
def test_disable_async_no_token_does_refresh():
    """With async disabled and no initial token, the first call refreshes once."""
    expiry = datetime.now() + timedelta(seconds=50)
    token = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token))
    result = source.token()
    assert source._refresh_count == 1
    assert result == token
79+
80+
81+
def test_disable_async_no_expiration_does_not_refresh():
    """With async disabled, a token without an expiry is never refreshed."""
    non_expiring_token = Token(access_token="access_token")
    source = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail)
    result = source.token()
    assert source._refresh_count == 0
    assert result == non_expiring_token
87+
88+
89+
def test_disable_async_fresh_does_not_refresh():
    """With async disabled, a fresh token is returned without refreshing."""
    # The token expires in 300 seconds, well beyond the 60 second stale
    # duration used by the mock, so it is fresh and must not be refreshed.
    expiry = datetime.now() + timedelta(seconds=300)
    token = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail)
    result = source.token()
    assert source._refresh_count == 0
    assert result == token
96+
97+
98+
def test_disable_async_expired_does_refresh():
    """With async disabled, an expired token is refreshed synchronously."""
    expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300))
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    # The refresh takes one second. If the call were not blocking, it would
    # return before the refresh finished and hand back the old token.
    slow_refresh = static_token(new_token, wait=1)
    source = _MockRefreshable(token=expired_token, disable_async=True, refresh_effect=slow_refresh)
    result = source.token()
    assert source._refresh_count == 1
    assert result == new_token
110+
111+
112+
def test_expired_does_refresh():
    """With async enabled, an expired token is still refreshed synchronously."""
    expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300))
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    # The refresh takes one second. If the call were not blocking, it would
    # return before the refresh finished and hand back the old token.
    slow_refresh = static_token(new_token, wait=1)
    source = _MockRefreshable(token=expired_token, disable_async=False, refresh_effect=slow_refresh)
    result = source.token()
    assert source._refresh_count == 1
    assert result == new_token
124+
125+
126+
def test_stale_does_refresh_async():
    """A stale token is returned immediately while a background refresh replaces it."""
    stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50))
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    # The refresh is held back until unblock() is called, so the first call
    # cannot race with it and observe the new token.
    refresh, unblock = blocking_refresh(new_token)
    source = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
    result = source.token()
    # The background refresh is still blocked, so the stale token is returned
    # and no refresh has completed yet.
    assert result == stale_token
    assert source._refresh_count == 0
    # Unblock the refresh and wait for it to finish.
    unblock()
    time.sleep(2)
    # A second call must now return the new token.
    result = source.token()
    assert result == new_token
    # Ensure that all calls have completed.
    time.sleep(0.1)
    assert source._refresh_count == 1
146+
147+
148+
def test_no_token_does_refresh():
    """With async enabled and no initial token, the first call blocks on a refresh."""
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    # The refresh takes one second. If the call were not blocking, it would
    # return before the refresh finished and no token would be returned.
    slow_refresh = static_token(new_token, wait=1)
    source = _MockRefreshable(token=None, disable_async=False, refresh_effect=slow_refresh)
    result = source.token()
    assert source._refresh_count == 1
    assert result == new_token
157+
158+
159+
def test_fresh_does_not_refresh():
    """With async enabled, a fresh token is returned without any refresh."""
    expiry = datetime.now() + timedelta(seconds=300)
    fresh_token = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail)
    result = source.token()
    assert source._refresh_count == 0
    assert result == fresh_token
165+
166+
167+
def test_multiple_calls_dont_start_many_threads():
    """Repeated calls on a stale token trigger only one background refresh."""
    stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59))
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    refresh, unblock = blocking_refresh(new_token)
    source = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
    # Call twice while the refresh is blocked; the second call must not
    # start another refresh.
    result = source.token()
    assert result == stale_token
    result = source.token()
    assert result == stale_token
    unblock()
    # Wait for the refresh to complete.
    time.sleep(1)
    result = source.token()
    # Only one refresh may have run.
    assert source._refresh_count == 1
    assert result == new_token
184+
185+
186+
def test_async_failure_disables_async():
    """A failed background refresh disables async refresh until a blocking refresh succeeds."""
    stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59))
    new_token = Token(access_token="new_token", expiry=datetime.now() + timedelta(seconds=300))
    source = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail)
    # The background refresh fails and disables async refresh, but the
    # exception is caught inside the worker thread.
    result = source.token()
    assert result == stale_token
    # Give the async refresh time to fail.
    time.sleep(1)
    assert source._refresh_err
    # From now on refreshes are blocking only. A blocking refresh happens
    # only for expired tokens, not stale ones, so the next call must still
    # return the stale token.
    source._refresh_effect = static_token(new_token, wait=1)
    result = source.token()
    assert result == stale_token
    # Wait to be sure that no async thread was started.
    time.sleep(1)
    assert source._refresh_count == 0

    # Inject an expired token.
    expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300))
    source._token = expired_token

    # This call must block and return the new token.
    result = source.token()
    assert source._refresh_count == 1
    assert result == new_token
    # The refresh error must have been cleared.
    assert not source._refresh_err

0 commit comments

Comments
 (0)