Skip to content

Commit e732ae8

Browse files
committed
Add support for async token refresh
1 parent 03312b3 commit e732ae8

File tree

9 files changed

+62
-164
lines changed

9 files changed

+62
-164
lines changed

NEXT_CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44

55
### New Features and Improvements
66
* Update Jobs service to use API 2.2 ([#913](https://github.com/databricks/databricks-sdk-py/pull/913)).
7+
* [Experimental] Add support for async token refresh ([#916](https://github.com/databricks/databricks-sdk-py/pull/916)).
8+
This can be enabled with by setting the following setting:
9+
```
10+
export DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH=1.
11+
```
12+
This feature and its setting are experimental and may be removed in future releases.
713

814
### Bug Fixes
915

databricks/sdk/__init__.py

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ class Config:
9494
max_connections_per_pool: int = ConfigAttribute()
9595
databricks_environment: Optional[DatabricksEnvironment] = None
9696

97+
enable_experimental_async_token_refresh: bool = ConfigAttribute(
98+
env="DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH"
99+
)
100+
97101
enable_experimental_files_api_client: bool = ConfigAttribute(env="DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT")
98102
files_api_client_download_max_total_recovers = None
99103
files_api_client_download_max_total_recovers_without_progressing = 1

databricks/sdk/credentials_provider.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
191191
token_url=oidc.token_endpoint,
192192
scopes=["all-apis"],
193193
use_header=True,
194+
disable_async=not cfg.enable_experimental_async_token_refresh,
194195
)
195196

196197
def inner() -> Dict[str, str]:
@@ -290,6 +291,7 @@ def token_source_for(resource: str) -> TokenSource:
290291
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
291292
endpoint_params={"resource": resource},
292293
use_params=True,
294+
disable_async=not cfg.enable_experimental_async_token_refresh,
293295
)
294296

295297
_ensure_host_present(cfg, token_source_for)
@@ -355,6 +357,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
355357
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
356358
endpoint_params=params,
357359
use_params=True,
360+
disable_async=not cfg.enable_experimental_async_token_refresh,
358361
)
359362

360363
def refreshed_headers() -> Dict[str, str]:
@@ -458,8 +461,9 @@ def __init__(
458461
token_type_field: str,
459462
access_token_field: str,
460463
expiry_field: str,
464+
disable_async: bool = True,
461465
):
462-
super().__init__()
466+
super().__init__(disable_async=disable_async)
463467
self._cmd = cmd
464468
self._token_type_field = token_type_field
465469
self._access_token_field = access_token_field
@@ -690,6 +694,7 @@ def __init__(self, cfg: "Config"):
690694
token_type_field="token_type",
691695
access_token_field="access_token",
692696
expiry_field="expiry",
697+
disable_async=not cfg.enable_experimental_async_token_refresh,
693698
)
694699

695700
@staticmethod

databricks/sdk/data_plane.py

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import threading
44
from dataclasses import dataclass
5-
from typing import Callable, List, Optional
5+
from typing import Callable, Optional
66
from urllib import parse
77

88
from databricks.sdk import oauth
@@ -88,61 +88,3 @@ class DataPlaneDetails:
8888
"""URL used to query the endpoint through the DataPlane."""
8989
token: Token
9090
"""Token to query the DataPlane endpoint."""
91-
92-
93-
## Old implementation. #TODO: Remove after the new implementation is used
94-
95-
96-
class DataPlaneService:
97-
"""Helper class to fetch and manage DataPlane details."""
98-
99-
from .service.serving import DataPlaneInfo
100-
101-
def __init__(self):
102-
self._data_plane_info = {}
103-
self._tokens = {}
104-
self._lock = threading.Lock()
105-
106-
def get_data_plane_details(
107-
self,
108-
method: str,
109-
params: List[str],
110-
info_getter: Callable[[], DataPlaneInfo],
111-
refresh: Callable[[str], Token],
112-
):
113-
"""Get and cache information required to query a Data Plane endpoint using the provided methods.
114-
115-
Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid.
116-
If not, it uses the provided functions to fetch the details.
117-
118-
:param method: method name. Used to construct a unique key for the cache.
119-
:param params: path params used in the "get" operation which uniquely determine the object. Used to construct a unique key for the cache.
120-
:param info_getter: function which returns the DataPlaneInfo. It will only be called if the information is not already present in the cache.
121-
:param refresh: function to refresh the token. It will only be called if the token is missing or expired.
122-
"""
123-
all_elements = params.copy()
124-
all_elements.insert(0, method)
125-
map_key = "/".join(all_elements)
126-
info = self._data_plane_info.get(map_key)
127-
if not info:
128-
self._lock.acquire()
129-
try:
130-
info = self._data_plane_info.get(map_key)
131-
if not info:
132-
info = info_getter()
133-
self._data_plane_info[map_key] = info
134-
finally:
135-
self._lock.release()
136-
137-
token = self._tokens.get(map_key)
138-
if not token or not token.valid:
139-
self._lock.acquire()
140-
token = self._tokens.get(map_key)
141-
try:
142-
if not token or not token.valid:
143-
token = refresh(info.authorization_details)
144-
self._tokens[map_key] = token
145-
finally:
146-
self._lock.release()
147-
148-
return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token)

databricks/sdk/oauth.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,16 @@ def __init__(
426426
client_id: str,
427427
client_secret: str = None,
428428
redirect_url: str = None,
429+
disable_async: bool = True,
429430
):
430431
self._token_endpoint = token_endpoint
431432
self._client_id = client_id
432433
self._client_secret = client_secret
433434
self._redirect_url = redirect_url
434-
super().__init__(token)
435+
super().__init__(
436+
token=token,
437+
disable_async=disable_async,
438+
)
435439

436440
def as_dict(self) -> dict:
437441
return {"token": self.token().as_dict()}
@@ -708,9 +712,10 @@ class ClientCredentials(Refreshable):
708712
scopes: List[str] = None
709713
use_params: bool = False
710714
use_header: bool = False
715+
disable_async: bool = True
711716

712717
def __post_init__(self):
713-
super().__init__()
718+
super().__init__(disable_async=self.disable_async)
714719

715720
def refresh(self) -> Token:
716721
params = {"grant_type": "client_credentials"}

databricks/sdk/service/serving.py

Lines changed: 25 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/integration/test_data_plane.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@ def test_data_plane_token_source(ucws, env_or_skip):
1313
dp_token = ts.token(info.endpoint_url, info.authorization_details)
1414

1515
assert dp_token.valid
16+
17+
18+
def test_model_serving_data_plane(ucws, env_or_skip):
19+
endpoint = env_or_skip("SERVING_ENDPOINT_NAME")
20+
serving_endpoints = ucws.serving_endpoints_data_plane
21+
response = serving_endpoints.query(name=endpoint, dataframe_records=[{"col": 1.0}])
22+
assert response is not None

tests/test_data_plane.py

Lines changed: 1 addition & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from urllib import parse
44

55
from databricks.sdk import data_plane, oauth
6-
from databricks.sdk.data_plane import DataPlaneService
76
from databricks.sdk.oauth import Token
8-
from databricks.sdk.service.serving import DataPlaneInfo
97

108
cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
119
dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
@@ -63,88 +61,11 @@ def token(self):
6361

6462
def test_token_source_get_token_existing(config):
6563
another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
66-
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True)
64+
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)
6765
token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)
6866

6967
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
7068
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
7169

7270
retrieve_token.assert_not_called()
7371
assert result_token.access_token == another_token.access_token
74-
75-
76-
## These tests are for the old implementation. #TODO: Remove after the new implementation is used
77-
78-
info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")
79-
80-
token = Token(
81-
access_token="token",
82-
token_type="type",
83-
expiry=datetime.now() + timedelta(hours=1),
84-
)
85-
86-
87-
class MockRefresher:
88-
89-
def __init__(self, expected: str):
90-
self._expected = expected
91-
92-
def __call__(self, auth_details: str) -> Token:
93-
assert self._expected == auth_details
94-
return token
95-
96-
97-
def throw_exception():
98-
raise Exception("Expected value to be cached")
99-
100-
101-
def test_not_cached():
102-
data_plane = DataPlaneService()
103-
res = data_plane.get_data_plane_details(
104-
"method",
105-
["params"],
106-
lambda: info,
107-
lambda a: MockRefresher(info.authorization_details).__call__(a),
108-
)
109-
assert res.endpoint_url == info.endpoint_url
110-
assert res.token == token
111-
112-
113-
def test_token_expired():
114-
expired = Token(
115-
access_token="expired",
116-
token_type="type",
117-
expiry=datetime.now() + timedelta(hours=-1),
118-
)
119-
data_plane = DataPlaneService()
120-
data_plane._tokens["method/params"] = expired
121-
res = data_plane.get_data_plane_details(
122-
"method",
123-
["params"],
124-
lambda: info,
125-
lambda a: MockRefresher(info.authorization_details).__call__(a),
126-
)
127-
assert res.endpoint_url == info.endpoint_url
128-
assert res.token == token
129-
130-
131-
def test_info_cached():
132-
data_plane = DataPlaneService()
133-
data_plane._data_plane_info["method/params"] = info
134-
res = data_plane.get_data_plane_details(
135-
"method",
136-
["params"],
137-
throw_exception,
138-
lambda a: MockRefresher(info.authorization_details).__call__(a),
139-
)
140-
assert res.endpoint_url == info.endpoint_url
141-
assert res.token == token
142-
143-
144-
def test_token_cached():
145-
data_plane = DataPlaneService()
146-
data_plane._data_plane_info["method/params"] = info
147-
data_plane._tokens["method/params"] = token
148-
res = data_plane.get_data_plane_details("method", ["params"], throw_exception, throw_exception)
149-
assert res.endpoint_url == info.endpoint_url
150-
assert res.token == token

0 commit comments

Comments
 (0)