diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 7ac3c892b..3a1cb4b29 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,12 @@ ## Release v0.46.0 ### New Features and Improvements +* [Experimental] Add support for async token refresh ([#916](https://github.com/databricks/databricks-sdk-py/pull/916)). + This can be enabled with by setting the following setting: + ``` + export DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH=1. + ``` + This feature and its setting are experimental and may be removed in future releases. ### Bug Fixes diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 48f59d48a..806d8c584 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -8,6 +8,7 @@ import databricks.sdk.service as service from databricks.sdk import azure from databricks.sdk.credentials_provider import CredentialsStrategy +from databricks.sdk.data_plane import DataPlaneTokenSource from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.files import DbfsExt, FilesExt from databricks.sdk.mixins.jobs import JobsExt @@ -285,8 +286,11 @@ def __init__( self._secrets = service.workspace.SecretsAPI(self._api_client) self._service_principals = service.iam.ServicePrincipalsAPI(self._api_client) self._serving_endpoints = serving_endpoints + serving_endpoints_data_plane_token_source = DataPlaneTokenSource( + self._config.host, self._config.oauth_token, not self._config.enable_experimental_async_token_refresh + ) self._serving_endpoints_data_plane = service.serving.ServingEndpointsDataPlaneAPI( - self._api_client, serving_endpoints + self._api_client, serving_endpoints, serving_endpoints_data_plane_token_source ) self._settings = service.settings.SettingsAPI(self._api_client) self._shares = service.sharing.SharesAPI(self._api_client) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 591aafc44..2a05cf6ba 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -95,6 +95,10 @@ class Config: max_connections_per_pool: int = ConfigAttribute() databricks_environment: Optional[DatabricksEnvironment] = None + enable_experimental_async_token_refresh: bool = ConfigAttribute( + env="DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH" + ) + enable_experimental_files_api_client: bool = ConfigAttribute(env="DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT") files_api_client_download_max_total_recovers = None files_api_client_download_max_total_recovers_without_progressing = 1 diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 86acac86c..eac7c9697 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -191,6 +191,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: token_url=oidc.token_endpoint, scopes=["all-apis"], use_header=True, + disable_async=not cfg.enable_experimental_async_token_refresh, ) def inner() -> Dict[str, str]: @@ -290,6 +291,7 @@ def token_source_for(resource: str) -> TokenSource: token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", endpoint_params={"resource": resource}, use_params=True, + disable_async=not cfg.enable_experimental_async_token_refresh, ) _ensure_host_present(cfg, token_source_for) @@ -355,6 +357,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", endpoint_params=params, use_params=True, + disable_async=not cfg.enable_experimental_async_token_refresh, ) def refreshed_headers() -> Dict[str, str]: @@ -458,8 +461,9 @@ def __init__( token_type_field: str, access_token_field: str, expiry_field: str, + disable_async: bool = True, ): - super().__init__() + super().__init__(disable_async=disable_async) self._cmd = cmd self._token_type_field = token_type_field self._access_token_field = access_token_field @@ -690,6 +694,7 @@ def __init__(self, cfg: "Config"): token_type_field="token_type", access_token_field="access_token", expiry_field="expiry", + disable_async=not cfg.enable_experimental_async_token_refresh, ) @staticmethod diff --git a/databricks/sdk/data_plane.py b/databricks/sdk/data_plane.py index 3c059ecf2..aa772edcc 100644 --- a/databricks/sdk/data_plane.py +++ b/databricks/sdk/data_plane.py @@ -2,7 +2,7 @@ import threading from dataclasses import dataclass -from typing import Callable, List, Optional +from typing import Callable, Optional from urllib import parse from databricks.sdk import oauth @@ -88,61 +88,3 @@ class DataPlaneDetails: """URL used to query the endpoint through the DataPlane.""" token: Token """Token to query the DataPlane endpoint.""" - - -## Old implementation. #TODO: Remove after the new implementation is used - - -class DataPlaneService: - """Helper class to fetch and manage DataPlane details.""" - - from .service.serving import DataPlaneInfo - - def __init__(self): - self._data_plane_info = {} - self._tokens = {} - self._lock = threading.Lock() - - def get_data_plane_details( - self, - method: str, - params: List[str], - info_getter: Callable[[], DataPlaneInfo], - refresh: Callable[[str], Token], - ): - """Get and cache information required to query a Data Plane endpoint using the provided methods. - - Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid. - If not, it uses the provided functions to fetch the details. - - :param method: method name. Used to construct a unique key for the cache. - :param params: path params used in the "get" operation which uniquely determine the object. Used to construct a unique key for the cache. - :param info_getter: function which returns the DataPlaneInfo. It will only be called if the information is not already present in the cache. - :param refresh: function to refresh the token. It will only be called if the token is missing or expired. - """ - all_elements = params.copy() - all_elements.insert(0, method) - map_key = "/".join(all_elements) - info = self._data_plane_info.get(map_key) - if not info: - self._lock.acquire() - try: - info = self._data_plane_info.get(map_key) - if not info: - info = info_getter() - self._data_plane_info[map_key] = info - finally: - self._lock.release() - - token = self._tokens.get(map_key) - if not token or not token.valid: - self._lock.acquire() - token = self._tokens.get(map_key) - try: - if not token or not token.valid: - token = refresh(info.authorization_details) - self._tokens[map_key] = token - finally: - self._lock.release() - - return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index d2df2f0f5..48b218f08 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -426,12 +426,16 @@ def __init__( client_id: str, client_secret: str = None, redirect_url: str = None, + disable_async: bool = True, ): self._token_endpoint = token_endpoint self._client_id = client_id self._client_secret = client_secret self._redirect_url = redirect_url - super().__init__(token) + super().__init__( + token=token, + disable_async=disable_async, + ) def as_dict(self) -> dict: return {"token": self.token().as_dict()} @@ -708,9 +712,10 @@ class ClientCredentials(Refreshable): scopes: List[str] = None use_params: bool = False use_header: bool = False + disable_async: bool = True def __post_init__(self): - super().__init__() + super().__init__(disable_async=self.disable_async) def refresh(self) -> Token: params = {"grant_type": "client_credentials"} diff --git a/databricks/sdk/service/serving.py b/databricks/sdk/service/serving.py index fccad9b79..ce65a795f 100755 --- a/databricks/sdk/service/serving.py +++ b/databricks/sdk/service/serving.py @@ -4,6 +4,7 @@ import logging import random +import threading import time from dataclasses import dataclass from datetime import timedelta @@ -4657,12 +4658,31 @@ class ServingEndpointsDataPlaneAPI: """Serving endpoints DataPlane provides a set of operations to interact with data plane endpoints for Serving endpoints service.""" - def __init__(self, api_client, control_plane): + def __init__(self, api_client, control_plane_service, dpts): self._api = api_client - self._control_plane = control_plane - from ..data_plane import DataPlaneService - - self._data_plane_service = DataPlaneService() + self._lock = threading.Lock() + self._control_plane_service = control_plane_service + self._dpts = dpts + self._data_plane_details = {} + + def _data_plane_info_query(self, name: str) -> DataPlaneInfo: + key = "query" + "/".join( + [ + str(name), + ] + ) + with self._lock: + if key in self._data_plane_details: + return self._data_plane_details[key] + response = self._control_plane_service.get( + name=name, + ) + if response.data_plane_info is None: + raise Exception("Resource does not support direct Data Plane access") + result = response.data_plane_info.query_info + with self._lock: + self._data_plane_details[key] = result + return result def query( self, @@ -4757,22 +4777,10 @@ def query( body["stream"] = stream if temperature is not None: body["temperature"] = temperature - - def info_getter(): - response = self._control_plane.get( - name=name, - ) - if response.data_plane_info is None: - raise Exception("Resource does not support direct Data Plane access") - return response.data_plane_info.query_info - - get_params = [ - name, - ] - data_plane_details = self._data_plane_service.get_data_plane_details( - "query", get_params, info_getter, self._api.get_oauth_token + data_plane_info = self._data_plane_info_query( + name=name, ) - token = data_plane_details.token + token = self._dpts.token(data_plane_info.endpoint_url, data_plane_info.authorization_details) def auth(r: requests.PreparedRequest) -> requests.PreparedRequest: authorization = f"{token.token_type} {token.access_token}" @@ -4788,7 +4796,7 @@ def auth(r: requests.PreparedRequest) -> requests.PreparedRequest: ] res = self._api.do( "POST", - url=data_plane_details.endpoint_url, + url=data_plane_info.endpoint_url, body=body, headers=headers, response_headers=response_headers, diff --git a/tests/integration/test_data_plane.py b/tests/integration/test_data_plane.py index 0062a7ed0..338366667 100644 --- a/tests/integration/test_data_plane.py +++ b/tests/integration/test_data_plane.py @@ -13,3 +13,10 @@ def test_data_plane_token_source(ucws, env_or_skip): dp_token = ts.token(info.endpoint_url, info.authorization_details) assert dp_token.valid + + +def test_model_serving_data_plane(ucws, env_or_skip): + endpoint = env_or_skip("SERVING_ENDPOINT_NAME") + serving_endpoints = ucws.serving_endpoints_data_plane + response = serving_endpoints.query(name=endpoint, dataframe_records=[{"col": 1.0}]) + assert response is not None diff --git a/tests/test_data_plane.py b/tests/test_data_plane.py index 54ace9ba7..d5956be57 100644 --- a/tests/test_data_plane.py +++ b/tests/test_data_plane.py @@ -3,9 +3,7 @@ from urllib import parse from databricks.sdk import data_plane, oauth -from databricks.sdk.data_plane import DataPlaneService from databricks.sdk.oauth import Token -from databricks.sdk.service.serving import DataPlaneInfo cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) @@ -63,7 +61,7 @@ def token(self): def test_token_source_get_token_existing(config): another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) - token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True) + token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True) token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token) with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: @@ -71,80 +69,3 @@ def test_token_source_get_token_existing(config): retrieve_token.assert_not_called() assert result_token.access_token == another_token.access_token - - -## These tests are for the old implementation. #TODO: Remove after the new implementation is used - -info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url") - -token = Token( - access_token="token", - token_type="type", - expiry=datetime.now() + timedelta(hours=1), -) - - -class MockRefresher: - - def __init__(self, expected: str): - self._expected = expected - - def __call__(self, auth_details: str) -> Token: - assert self._expected == auth_details - return token - - -def throw_exception(): - raise Exception("Expected value to be cached") - - -def test_not_cached(): - data_plane = DataPlaneService() - res = data_plane.get_data_plane_details( - "method", - ["params"], - lambda: info, - lambda a: MockRefresher(info.authorization_details).__call__(a), - ) - assert res.endpoint_url == info.endpoint_url - assert res.token == token - - -def test_token_expired(): - expired = Token( - access_token="expired", - token_type="type", - expiry=datetime.now() + timedelta(hours=-1), - ) - data_plane = DataPlaneService() - data_plane._tokens["method/params"] = expired - res = data_plane.get_data_plane_details( - "method", - ["params"], - lambda: info, - lambda a: MockRefresher(info.authorization_details).__call__(a), - ) - assert res.endpoint_url == info.endpoint_url - assert res.token == token - - -def test_info_cached(): - data_plane = DataPlaneService() - data_plane._data_plane_info["method/params"] = info - res = data_plane.get_data_plane_details( - "method", - ["params"], - throw_exception, - lambda a: MockRefresher(info.authorization_details).__call__(a), - ) - assert res.endpoint_url == info.endpoint_url - assert res.token == token - - -def test_token_cached(): - data_plane = DataPlaneService() - data_plane._data_plane_info["method/params"] = info - data_plane._tokens["method/params"] = token - res = data_plane.get_data_plane_details("method", ["params"], throw_exception, throw_exception) - assert res.endpoint_url == info.endpoint_url - assert res.token == token