class DataPlaneTokenSource:
    """
    EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.

    One ``DataPlaneEndpointTokenSource`` is created lazily for each distinct
    ``endpoint``/``auth_details`` combination and cached for later calls.

    :param token_exchange_host: Host handed to every per-endpoint token source.
    :param cpts: Callable handed to every per-endpoint token source; it returns a control plane token.
    :param disable_async: Forwarded unchanged to every per-endpoint token source.
    """

    # TODO: Enable async once it's stable. @oauth_credentials_provider must also have async enabled.
    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
        self._cpts = cpts
        self._token_exchange_host = token_exchange_host
        # Cache of per-endpoint token sources, keyed by "<endpoint>:<auth_details>".
        self._token_sources: dict[str, DataPlaneEndpointTokenSource] = {}
        self._disable_async = disable_async
        # Serializes creation of new cache entries.
        self._lock = threading.Lock()

    def token(self, endpoint: str, auth_details: str) -> Token:
        """
        Return a token for the given DataPlane endpoint.

        :param endpoint: Endpoint identifier; used only as part of the cache key.
        :param auth_details: Authorization details; part of the cache key and passed to the
            per-endpoint token source when one has to be created.
        :return: Whatever the cached (or newly created) per-endpoint token source returns from ``token()``.
        """
        key = f"{endpoint}:{auth_details}"

        # First, try to read without acquiring the lock to avoid contention.
        token_source = self._token_sources.get(key)

        # Compare against None explicitly: a cached source must be reused even if the
        # object happens to evaluate as falsy.
        if token_source is None:
            with self._lock:
                # Another thread might have created it while we were waiting for the lock.
                token_source = self._token_sources.get(key)
                if token_source is None:
                    token_source = DataPlaneEndpointTokenSource(
                        self._token_exchange_host, self._cpts, auth_details, self._disable_async
                    )
                    self._token_sources[key] = token_source

        return token_source.token()
def test_data_plane_token_source(ucws, env_or_skip):
    """Fetch a data plane token for the serving endpoint named by SERVING_ENDPOINT_NAME.

    The test is skipped by ``env_or_skip`` when the environment variable is not set.
    """
    endpoint = env_or_skip("SERVING_ENDPOINT_NAME")
    serving_endpoint = ucws.serving_endpoints.get(endpoint)
    assert serving_endpoint.data_plane_info is not None
    assert serving_endpoint.data_plane_info.query_info is not None

    info = serving_endpoint.data_plane_info.query_info

    # Read both the host and the token callable through the public `config`
    # accessor instead of mixing it with the private `_config` attribute.
    # NOTE(review): assumes `ucws.config` exposes the same object as `ucws._config` - confirm.
    ts = DataPlaneTokenSource(ucws.config.host, ucws.config.oauth_token)
    dp_token = ts.token(info.endpoint_url, info.authorization_details)

    assert dp_token.valid
def success_callable(token: oauth.Token):
    """Return a zero-argument callable that always returns ``token``."""

    def success() -> oauth.Token:
        return token

    return success


def test_endpoint_token_source_get_token(config):
    """Check the arguments the endpoint token source passes to ``oauth.retrieve_token``."""
    token_source = data_plane.DataPlaneEndpointTokenSource(
        config.host, success_callable(cp_token), "authDetails", disable_async=True
    )

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        token_source.token()

    retrieve_token.assert_called_once()
    # Only the keyword arguments are inspected; positional ones are discarded.
    _, kwargs = retrieve_token.call_args

    assert kwargs["token_url"] == config.host + "/oidc/v1/token"
    assert kwargs["params"] == parse.urlencode(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "authorization_details": "authDetails",
            "assertion": cp_token.access_token,
        }
    )
    assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}


def test_token_source_get_token_not_existing(config):
    """Check that a missing cache entry is created and used to fetch the token."""
    token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")

    retrieve_token.assert_called_once()
    assert result_token.access_token == dp_token.access_token
    assert "endpoint:authDetails" in token_source._token_sources
def test_token_source_get_token_existing(config):
    """Check that a pre-populated cache entry is reused without calling ``oauth.retrieve_token``."""
    another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
    # Use `cp_token` like the sibling tests do, rather than the module-level `token`
    # that belongs to the old-implementation tests scheduled for removal.
    token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)
    token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")

    retrieve_token.assert_not_called()
    assert result_token.access_token == another_token.access_token