Skip to content

Commit 011890d

Browse files
committed
WIP
1 parent 998a117 commit 011890d

File tree

2 files changed

+342
-6
lines changed

2 files changed

+342
-6
lines changed

databricks/sdk/oauth.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from abc import abstractmethod
1212
from dataclasses import dataclass
1313
from datetime import datetime, timedelta
14+
from enum import Enum
1415
from http.server import BaseHTTPRequestHandler, HTTPServer
1516
from typing import Any, Dict, List, Optional
17+
from concurrent.futures import ThreadPoolExecutor
1618

1719
import requests
1820
import requests.auth
@@ -186,22 +188,112 @@ def retrieve_token(client_id,
186188
except Exception as e:
187189
raise NotImplementedError(f"Not supported yet: {e}")
188190

191+
class _TokenState(Enum):
192+
"""
193+
tokenState represents the state of the token. Each token can be in one of
194+
the following three states:
195+
- FRESH: The token is valid.
196+
- STALE: The token is valid but will expire soon.
197+
- EXPIRED: The token has expired and cannot be used.
198+
199+
Token state through time:
200+
issue time expiry time
201+
v v
202+
| fresh | stale | expired -> time
203+
| valid |
204+
"""
205+
FRESH = 1 # The token is valid.
206+
STALE = 2 # The token is valid but will expire soon.
207+
EXPIRED = 3 # The token has expired and cannot be used.
208+
189209

190210
class Refreshable(TokenSource):
211+
_executor = ThreadPoolExecutor(max_workers=10)
212+
_default_stale_duration = 3
191213

192-
def __init__(self, token=None):
214+
def __init__(self, token=None, disable_async = True, stale_duration=timedelta(minutes=_default_stale_duration)):
193215
self._lock = threading.Lock() # to guard _token
194216
self._token = token
217+
self._stale_duration = stale_duration
218+
self._disable_async = disable_async
219+
self._is_refreshing = False
220+
self._refresh_err = False
195221

196-
def token(self) -> Token:
222+
def token(self, blocking=False) -> Token:
223+
if self._disable_async:
224+
return self._blocking_token()
225+
return self._async_token()
226+
227+
def _async_token(self) -> Token:
197228
self._lock.acquire()
198-
try:
199-
if self._token and self._token.valid:
229+
token_state = self._token_state()
230+
token = self._token
231+
self._lock.release()
232+
match token_state:
233+
case _TokenState.FRESH:
234+
return token
235+
case _TokenState.STALE:
236+
self._trigger_async_refresh()
237+
return token
238+
case _: #Expired
239+
return self._blocking_token()
240+
241+
242+
def _token_state(self) -> _TokenState:
243+
"""
244+
Returns the state of the token.
245+
"""
246+
# Invalid tokens are considered expired.
247+
if not self._token or not self._token.valid:
248+
return _TokenState.EXPIRED
249+
# Tokens without an expiry are considered always.
250+
if not self._token.expiry:
251+
return _TokenState.FRESH
252+
lifespan = self._token.expiry - datetime.now()
253+
if lifespan < timedelta(seconds=0):
254+
return _TokenState.EXPIRED
255+
if lifespan < self._stale_duration:
256+
return _TokenState.STALE
257+
return _TokenState.FRESH
258+
259+
def _blocking_token(self) -> Token:
260+
261+
# The lock is kept for the entire operation to ensure that only one
262+
# refresh operation is running at a time.
263+
with self._lock:
264+
# This is important to recover from potential previous failed attempts
265+
# to refresh the token asynchronously, see declaration of refresh_err for
266+
# more information.
267+
self._refresh_err = False
268+
self._is_refreshing = False
269+
270+
# It's possible that the token got refreshed (either by a _blocking_refresh or
271+
# an _async_refresh call) while this particular call was waiting to acquire
272+
# the lock. This check avoids refreshing the token again in such cases.
273+
if self._token_state() != _TokenState.EXPIRED:
200274
return self._token
275+
276+
# Refresh the token
201277
self._token = self.refresh()
202278
return self._token
203-
finally:
204-
self._lock.release()
279+
280+
281+
def _trigger_async_refresh(self):
282+
# Note: this is not thread safe.
283+
# Only call it inside the lock.
284+
def _refresh_internal():
285+
try:
286+
self._token = self.refresh()
287+
except Exception:
288+
self._refresh_err = True
289+
finally:
290+
self._is_refreshing = False
291+
# The lock is kept for the entire operation to ensure that only one
292+
# refresh operation is running at a time.
293+
with self._lock:
294+
if not self._is_refreshing and not self._refresh_err:
295+
self._is_refreshing = True
296+
self._executor.submit(_refresh_internal)
205297

206298
@abstractmethod
207299
def refresh(self) -> Token:

tests/test_refreshable.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import time
2+
from time import sleep
3+
from typing import Callable
4+
5+
import pytest
6+
7+
from datetime import datetime, timedelta
8+
9+
from databricks.sdk.oauth import Refreshable, Token
10+
11+
12+
class _MockRefreshable(Refreshable):
13+
14+
def __init__(self, disable_async, token=None, stale_duration=timedelta(seconds=60), refresh_effect: Callable[[], Token]=None):
15+
super().__init__(token, disable_async, stale_duration)
16+
self._refresh_effect = refresh_effect
17+
self._refresh_count = 0
18+
19+
def refresh(self) -> Token:
20+
if self._refresh_effect:
21+
self._token = self._refresh_effect()
22+
self._refresh_count += 1
23+
return self._token
24+
25+
def fail() -> Token:
26+
raise Exception("Failed to refresh token")
27+
28+
def static_token(token: Token, wait: int=0) -> Callable[[], Token]:
29+
def f() -> Token:
30+
time.sleep(wait)
31+
return token
32+
return f
33+
34+
def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[],None]):
35+
"""
36+
Create a refresh function that blocks until unblock is called.
37+
38+
Param:
39+
token: the token that will be returned
40+
41+
Returns:
42+
A tuple containing the refresh function and the unblock function.
43+
44+
"""
45+
blocking = True
46+
def refresh():
47+
while blocking:
48+
sleep(0.1)
49+
return token
50+
def unblock():
51+
nonlocal blocking
52+
blocking = False
53+
return refresh, unblock
54+
55+
56+
def test_disable_async_stale_does_not_refresh():
57+
stale_token = Token(
58+
access_token="access_token",
59+
expiry=datetime.now() + timedelta(seconds=50),
60+
)
61+
r = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail)
62+
result = r.token()
63+
assert r._refresh_count == 0
64+
assert result == stale_token
65+
66+
def test_disable_async_no_token_does_refresh():
67+
token = Token(
68+
access_token="access_token",
69+
expiry=datetime.now() + timedelta(seconds=50),
70+
)
71+
r = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token))
72+
result = r.token()
73+
assert r._refresh_count == 1
74+
assert result == token
75+
76+
def test_disable_async_no_expiration_does_not_refresh():
77+
non_expiring_token = Token(
78+
access_token="access_token",
79+
)
80+
r = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail)
81+
result = r.token()
82+
assert r._refresh_count == 0
83+
assert result == non_expiring_token
84+
85+
def test_disable_async_fresh_does_not_refresh():
86+
# Create a token that is already stale. If async is disabled, the token should not be refreshed.
87+
token = Token(
88+
access_token="access_token",
89+
expiry=datetime.now() + timedelta(seconds=300),
90+
)
91+
r = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail)
92+
result = r.token()
93+
assert r._refresh_count == 0
94+
assert result == token
95+
96+
def test_disable_async_expired_does_refresh():
97+
expired_token = Token(
98+
access_token="access_token",
99+
expiry=datetime.now() - timedelta(seconds=300),
100+
)
101+
new_token = Token(
102+
access_token="access_token",
103+
expiry=datetime.now() + timedelta(seconds=300),
104+
)
105+
# Add one second to the refresh time to ensure that the call is blocking.
106+
# If the call is not blocking, the wait time will ensure that the
107+
# old token is returned.
108+
r = _MockRefreshable(token=expired_token, disable_async=True, refresh_effect=static_token(new_token, wait=1))
109+
result = r.token()
110+
assert r._refresh_count == 1
111+
assert result == new_token
112+
113+
def test_expired_does_refresh():
114+
expired_token = Token(
115+
access_token="access_token",
116+
expiry=datetime.now() - timedelta(seconds=300),
117+
)
118+
new_token = Token(
119+
access_token="access_token",
120+
expiry=datetime.now() + timedelta(seconds=300),
121+
)
122+
# Add one second to the refresh time to ensure that the call is blocking.
123+
# If the call is not blocking, the wait time will ensure that the
124+
# old token is returned.
125+
r = _MockRefreshable(token=expired_token, disable_async=False, refresh_effect=static_token(new_token, wait=1))
126+
result = r.token()
127+
assert r._refresh_count == 1
128+
assert result == new_token
129+
130+
def test_stale_does_refresh_async():
131+
stale_token = Token(
132+
access_token="access_token",
133+
expiry=datetime.now() + timedelta(seconds=50),
134+
)
135+
new_token = Token(
136+
access_token="access_token",
137+
expiry=datetime.now() + timedelta(seconds=300),
138+
)
139+
# Add one second to the refresh to avoid race conditions.
140+
# Without it, the new token may be returned in some cases.
141+
refresh, unblock = blocking_refresh(new_token)
142+
r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
143+
result = r.token()
144+
# NOTE: Do not check for refresh count here, since the
145+
assert result == stale_token
146+
assert r._refresh_count == 0
147+
# Unblock the refresh and wait
148+
unblock()
149+
time.sleep(2)
150+
# Call again and check that you get the new token
151+
result = r.token()
152+
assert result == new_token
153+
# Ensure that all calls have completed
154+
time.sleep(0.1)
155+
assert r._refresh_count == 1
156+
157+
158+
def test_no_token_does_refresh():
159+
new_token = Token(
160+
access_token="access_token",
161+
expiry=datetime.now() + timedelta(seconds=300),
162+
)
163+
# Add one second to the refresh time to ensure that the call is blocking.
164+
# If the call is not blocking, the wait time will ensure that the
165+
# token is not returned.
166+
r = _MockRefreshable(token=None, disable_async=False, refresh_effect=static_token(new_token, wait=1))
167+
result = r.token()
168+
assert r._refresh_count == 1
169+
assert result == new_token
170+
171+
def test_fresh_does_not_refresh():
172+
fresh_token = Token(
173+
access_token="access_token",
174+
expiry=datetime.now() + timedelta(seconds=300),
175+
)
176+
r = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail)
177+
result = r.token()
178+
assert r._refresh_count == 0
179+
assert result == fresh_token
180+
181+
def test_multiple_calls_dont_start_many_threads():
182+
stale_token = Token(
183+
access_token="access_token",
184+
expiry=datetime.now() + timedelta(seconds=59),
185+
)
186+
new_token = Token(
187+
access_token="access_token",
188+
expiry=datetime.now() + timedelta(seconds=300),
189+
)
190+
refresh, unblock = blocking_refresh(new_token)
191+
r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
192+
# Call twice. The second call should not start a new thread.
193+
result = r.token()
194+
assert result == stale_token
195+
result = r.token()
196+
assert result == stale_token
197+
unblock()
198+
# Wait for the refresh to complete
199+
time.sleep(1)
200+
result = r.token()
201+
# Check that only one refresh was called
202+
assert r._refresh_count == 1
203+
assert result == new_token
204+
205+
def test_async_failure_disables_async():
206+
stale_token = Token(
207+
access_token="access_token",
208+
expiry=datetime.now() + timedelta(seconds=59),
209+
)
210+
new_token = Token(
211+
access_token="new_token",
212+
expiry=datetime.now() + timedelta(seconds=300),
213+
)
214+
r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail)
215+
# The call should fail and disable async refresh,
216+
# but the exception will be catch inside the tread.
217+
result = r.token()
218+
assert result == stale_token
219+
# Give time to the async refresh to fail
220+
time.sleep(1)
221+
assert r._refresh_err
222+
# Now, the refresh should be blocking.
223+
# Blocking refresh only happens for expired, not stale.
224+
# Therefore, the next call should return the stale token.
225+
r._refresh_effect = static_token(new_token, wait=1)
226+
result = r.token()
227+
assert result == stale_token
228+
# Wait to be sure no async thread was started
229+
time.sleep(1)
230+
assert r._refresh_count == 0
231+
232+
# Inject an expired token.
233+
expired_token = Token(
234+
access_token="access_token",
235+
expiry=datetime.now() - timedelta(seconds=300),
236+
)
237+
r._token = expired_token
238+
239+
# This should be blocking and return the new token.
240+
result = r.token()
241+
assert r._refresh_count == 1
242+
assert result == new_token
243+
# The refresh error should be cleared.
244+
assert not r._refresh_err

0 commit comments

Comments
 (0)