Skip to content

Commit 95e8633

Browse files
committed
Mega block
1 parent 21c999f commit 95e8633

File tree

1 file changed

+35
-38
lines changed

1 file changed

+35
-38
lines changed

databricks/sdk/oauth.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -228,26 +228,28 @@ def __init__(self,
228228
self._disable_async = disable_async
229229
# Lock
230230
self._lock = threading.Lock()
231-
# Non Thread safe properties. Protected by the lock above.
231+
# Non Thread safe properties. They should be accessed only when protected by the lock above.
232232
self._token = token
233233
self._is_refreshing = False
234234
self._refresh_err = False
235235

236+
# This is the main entry point for the Token. Do not access the token
237+
# using any of the internal functions.
236238
def token(self) -> Token:
237239
"""Returns a valid token, blocking if async refresh is disabled."""
238-
if self._disable_async:
239-
return self._blocking_token()
240-
return self._async_token()
240+
with self._lock:
241+
if self._disable_async:
242+
return self._blocking_token()
243+
return self._async_token()
241244

242245
def _async_token(self) -> Token:
243246
"""
244247
Returns a token.
245248
If the token is stale, triggers an asynchronous refresh.
246249
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
247250
"""
248-
with self._lock:
249-
state = Refreshable._token_state(self._token, self._stale_duration)
250-
token = self._token
251+
state = self._token_state()
252+
token = self._token
251253

252254
if state == _TokenState.FRESH:
253255
return token
@@ -256,41 +258,38 @@ def _async_token(self) -> Token:
256258
return token
257259
return self._blocking_token()
258260

259-
# This is a class method and we pass the token to avoid
260-
# concurrency issues and deadlocks.
261-
@classmethod
262-
def _token_state(cls, token: Token, stale_duration: timedelta) -> _TokenState:
261+
262+
def _token_state(self) -> _TokenState:
263263
"""Returns the current state of the token."""
264-
if not token or not token.valid:
264+
if not self._token or not self._token.valid:
265265
return _TokenState.EXPIRED
266-
if not token.expiry:
266+
if not self._token.expiry:
267267
return _TokenState.FRESH
268268

269-
lifespan = token.expiry - datetime.now()
269+
lifespan = self._token.expiry - datetime.now()
270270
if lifespan < timedelta(seconds=0):
271271
return _TokenState.EXPIRED
272-
if lifespan < stale_duration:
272+
if lifespan < self._stale_duration:
273273
return _TokenState.STALE
274274
return _TokenState.FRESH
275275

276276
def _blocking_token(self) -> Token:
277277
"""Returns a token, blocking if necessary to refresh it."""
278-
with self._lock:
279-
state = Refreshable._token_state(self._token, self._stale_duration)
280-
# This is important to recover from potential previous failed attempts
281-
# to refresh the token asynchronously.
282-
self._refresh_err = False
283-
self._is_refreshing = False
284-
285-
# It's possible that the token got refreshed (either by a _blocking_refresh or
286-
# an _async_refresh call) while this particular call was waiting to acquire
287-
# the lock. This check avoids refreshing the token again in such cases.
288-
if state != _TokenState.EXPIRED:
289-
return self._token
290-
291-
self._token = self.refresh()
278+
state = self._token_state()
279+
# This is important to recover from potential previous failed attempts
280+
# to refresh the token asynchronously.
281+
self._refresh_err = False
282+
self._is_refreshing = False
283+
284+
# It's possible that the token got refreshed (either by a _blocking_refresh or
285+
# an _async_refresh call) while this particular call was waiting to acquire
286+
# the lock. This check avoids refreshing the token again in such cases.
287+
if state != _TokenState.EXPIRED:
292288
return self._token
293289

290+
self._token = self.refresh()
291+
return self._token
292+
294293
def _trigger_async_refresh(self):
295294
"""Starts an asynchronous refresh if none is in progress."""
296295

@@ -311,14 +310,12 @@ def _refresh_internal():
311310
self._refresh_err = True
312311
self._is_refreshing = False
313312

314-
with self._lock:
315-
state = Refreshable._token_state(self._token, self._stale_duration)
316-
# The token may have been refreshed by another thread.
317-
if state == _TokenState.FRESH:
318-
return
319-
if not self._is_refreshing and not self._refresh_err:
320-
self._is_refreshing = True
321-
Refreshable._get_executor().submit(_refresh_internal)
313+
# The token may have been refreshed by another thread.
314+
if self._token_state() == _TokenState.FRESH:
315+
return
316+
if not self._is_refreshing and not self._refresh_err:
317+
self._is_refreshing = True
318+
Refreshable._get_executor().submit(_refresh_internal)
322319

323320
@abstractmethod
324321
def refresh(self) -> Token:
@@ -412,7 +409,7 @@ def __init__(self,
412409
super().__init__(token)
413410

414411
def as_dict(self) -> dict:
415-
return {'token': self._token.as_dict()}
412+
return {'token': self.token().as_dict()}
416413

417414
@staticmethod
418415
def from_dict(raw: dict,

0 commit comments

Comments
 (0)