Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
## Release v0.46.0

### New Features and Improvements
* [Experimental] Add support for async token refresh ([#916](https://github.com/databricks/databricks-sdk-py/pull/916)).
This can be enabled with by setting the following setting:
```
export DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH=1.
```
This feature and its setting are experimental and may be removed in future releases.

### Bug Fixes

Expand Down
6 changes: 5 additions & 1 deletion databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class Config:
max_connections_per_pool: int = ConfigAttribute()
databricks_environment: Optional[DatabricksEnvironment] = None

enable_experimental_async_token_refresh: bool = ConfigAttribute(
env="DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH"
)

enable_experimental_files_api_client: bool = ConfigAttribute(env="DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT")
files_api_client_download_max_total_recovers = None
files_api_client_download_max_total_recovers_without_progressing = 1
Expand Down
7 changes: 6 additions & 1 deletion databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
token_url=oidc.token_endpoint,
scopes=["all-apis"],
use_header=True,
disable_async=not cfg.enable_experimental_async_token_refresh,
)

def inner() -> Dict[str, str]:
Expand Down Expand Up @@ -290,6 +291,7 @@ def token_source_for(resource: str) -> TokenSource:
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
endpoint_params={"resource": resource},
use_params=True,
disable_async=not cfg.enable_experimental_async_token_refresh,
)

_ensure_host_present(cfg, token_source_for)
Expand Down Expand Up @@ -355,6 +357,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
endpoint_params=params,
use_params=True,
disable_async=not cfg.enable_experimental_async_token_refresh,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -458,8 +461,9 @@ def __init__(
token_type_field: str,
access_token_field: str,
expiry_field: str,
disable_async: bool = True,
):
super().__init__()
super().__init__(disable_async=disable_async)
self._cmd = cmd
self._token_type_field = token_type_field
self._access_token_field = access_token_field
Expand Down Expand Up @@ -690,6 +694,7 @@ def __init__(self, cfg: "Config"):
token_type_field="token_type",
access_token_field="access_token",
expiry_field="expiry",
disable_async=not cfg.enable_experimental_async_token_refresh,
)

@staticmethod
Expand Down
60 changes: 1 addition & 59 deletions databricks/sdk/data_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import threading
from dataclasses import dataclass
from typing import Callable, List, Optional
from typing import Callable, Optional
from urllib import parse

from databricks.sdk import oauth
Expand Down Expand Up @@ -88,61 +88,3 @@ class DataPlaneDetails:
"""URL used to query the endpoint through the DataPlane."""
token: Token
"""Token to query the DataPlane endpoint."""


## Old implementation. #TODO: Remove after the new implementation is used


class DataPlaneService:
"""Helper class to fetch and manage DataPlane details."""

from .service.serving import DataPlaneInfo

def __init__(self):
self._data_plane_info = {}
self._tokens = {}
self._lock = threading.Lock()

def get_data_plane_details(
self,
method: str,
params: List[str],
info_getter: Callable[[], DataPlaneInfo],
refresh: Callable[[str], Token],
):
"""Get and cache information required to query a Data Plane endpoint using the provided methods.

Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid.
If not, it uses the provided functions to fetch the details.

:param method: method name. Used to construct a unique key for the cache.
:param params: path params used in the "get" operation which uniquely determine the object. Used to construct a unique key for the cache.
:param info_getter: function which returns the DataPlaneInfo. It will only be called if the information is not already present in the cache.
:param refresh: function to refresh the token. It will only be called if the token is missing or expired.
"""
all_elements = params.copy()
all_elements.insert(0, method)
map_key = "/".join(all_elements)
info = self._data_plane_info.get(map_key)
if not info:
self._lock.acquire()
try:
info = self._data_plane_info.get(map_key)
if not info:
info = info_getter()
self._data_plane_info[map_key] = info
finally:
self._lock.release()

token = self._tokens.get(map_key)
if not token or not token.valid:
self._lock.acquire()
token = self._tokens.get(map_key)
try:
if not token or not token.valid:
token = refresh(info.authorization_details)
self._tokens[map_key] = token
finally:
self._lock.release()

return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token)
9 changes: 7 additions & 2 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,16 @@ def __init__(
client_id: str,
client_secret: str = None,
redirect_url: str = None,
disable_async: bool = True,
):
self._token_endpoint = token_endpoint
self._client_id = client_id
self._client_secret = client_secret
self._redirect_url = redirect_url
super().__init__(token)
super().__init__(
token=token,
disable_async=disable_async,
)

def as_dict(self) -> dict:
return {"token": self.token().as_dict()}
Expand Down Expand Up @@ -708,9 +712,10 @@ class ClientCredentials(Refreshable):
scopes: List[str] = None
use_params: bool = False
use_header: bool = False
disable_async: bool = True

def __post_init__(self):
super().__init__()
super().__init__(disable_async=self.disable_async)

def refresh(self) -> Token:
params = {"grant_type": "client_credentials"}
Expand Down
50 changes: 29 additions & 21 deletions databricks/sdk/service/serving.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/integration/test_data_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ def test_data_plane_token_source(ucws, env_or_skip):
dp_token = ts.token(info.endpoint_url, info.authorization_details)

assert dp_token.valid


def test_model_serving_data_plane(ucws, env_or_skip):
endpoint = env_or_skip("SERVING_ENDPOINT_NAME")
serving_endpoints = ucws.serving_endpoints_data_plane
response = serving_endpoints.query(name=endpoint, dataframe_records=[{"col": 1.0}])
assert response is not None
81 changes: 1 addition & 80 deletions tests/test_data_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from urllib import parse

from databricks.sdk import data_plane, oauth
from databricks.sdk.data_plane import DataPlaneService
from databricks.sdk.oauth import Token
from databricks.sdk.service.serving import DataPlaneInfo

cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
Expand Down Expand Up @@ -63,88 +61,11 @@ def token(self):

def test_token_source_get_token_existing(config):
another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True)
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)
token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)

with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")

retrieve_token.assert_not_called()
assert result_token.access_token == another_token.access_token


## These tests are for the old implementation. #TODO: Remove after the new implementation is used

info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")

token = Token(
access_token="token",
token_type="type",
expiry=datetime.now() + timedelta(hours=1),
)


class MockRefresher:

def __init__(self, expected: str):
self._expected = expected

def __call__(self, auth_details: str) -> Token:
assert self._expected == auth_details
return token


def throw_exception():
raise Exception("Expected value to be cached")


def test_not_cached():
data_plane = DataPlaneService()
res = data_plane.get_data_plane_details(
"method",
["params"],
lambda: info,
lambda a: MockRefresher(info.authorization_details).__call__(a),
)
assert res.endpoint_url == info.endpoint_url
assert res.token == token


def test_token_expired():
expired = Token(
access_token="expired",
token_type="type",
expiry=datetime.now() + timedelta(hours=-1),
)
data_plane = DataPlaneService()
data_plane._tokens["method/params"] = expired
res = data_plane.get_data_plane_details(
"method",
["params"],
lambda: info,
lambda a: MockRefresher(info.authorization_details).__call__(a),
)
assert res.endpoint_url == info.endpoint_url
assert res.token == token


def test_info_cached():
data_plane = DataPlaneService()
data_plane._data_plane_info["method/params"] = info
res = data_plane.get_data_plane_details(
"method",
["params"],
throw_exception,
lambda a: MockRefresher(info.authorization_details).__call__(a),
)
assert res.endpoint_url == info.endpoint_url
assert res.token == token


def test_token_cached():
data_plane = DataPlaneService()
data_plane._data_plane_info["method/params"] = info
data_plane._tokens["method/params"] = token
res = data_plane.get_data_plane_details("method", ["params"], throw_exception, throw_exception)
assert res.endpoint_url == info.endpoint_url
assert res.token == token
Loading