diff --git a/CHANGELOG.md b/CHANGELOG.md index 55aeb386b..da265c621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Version changelog +## [Unreleased] + +### New Features and Improvements + +* Add support for unified hosts, i.e. hosts that support both workspace-level and account-level operations +* Deprecate `Config.is_account_client`, which will not work for unified hosts, and replace it with `Config.host_type()` and `Config.config_type()` methods +* Add validation in `WorkspaceClient` and `AccountClient` constructors to ensure configs are appropriate for the client type + ## Release v0.71.0 (2025-10-30) ### Bug Fixes diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 0d285ccda..437e318f7 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -247,6 +247,15 @@ def __init__( product_version=product_version, token_audience=token_audience, ) + + # Validate that the config is appropriate for a WorkspaceClient + from .config import HostType + host_type = config.host_type() + if host_type == HostType.ACCOUNT_HOST: + raise ValueError("invalid Databricks Workspace configuration - host is not a workspace host") + if host_type == HostType.UNIFIED_HOST and not config.workspace_id: + raise ValueError("workspace_id must be set when using WorkspaceClient with unified host") + self._config = config.copy() self._dbutils = _make_dbutils(self._config) self._api_client = client.ApiClient(self._config) @@ -1081,6 +1090,16 @@ def __init__( product_version=product_version, token_audience=token_audience, ) + + # Validate that the config is appropriate for an AccountClient + from .config import HostType + if not config.account_id or config.host_type() == HostType.WORKSPACE_HOST: + raise ValueError("invalid Databricks Account configuration - host incorrect or account_id missing") + # workspace_id must NOT be present in a config used with AccountClient because + # unified hosts route calls based on the presence of the 
X-Databricks-Org-Id header. + if config.workspace_id: + raise ValueError("workspace_id must not be set when using AccountClient") + self._config = config.copy() self._api_client = client.ApiClient(self._config) self._access_control = pkg_iam.AccountAccessControlAPI(self._api_client) @@ -1333,6 +1352,7 @@ def get_workspace_client(self, workspace: Workspace) -> WorkspaceClient: config = self._config.deep_copy() config.host = config.environment.deployment_url(workspace.deployment_name) config.azure_workspace_resource_id = azure.get_azure_resource_id(workspace) + config.workspace_id = str(workspace.workspace_id) config.account_id = None config.init_auth() return WorkspaceClient(config=config) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 879ba64ec..4207e75b7 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -6,6 +6,7 @@ import pathlib import sys import urllib.parse +from enum import Enum from typing import Dict, Iterable, List, Optional import requests @@ -18,12 +19,26 @@ from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) from .oauth import (OidcEndpoints, Token, get_account_endpoints, - get_azure_entra_id_workspace_endpoints, + get_azure_entra_id_workspace_endpoints, get_unified_endpoints, get_workspace_endpoints) logger = logging.getLogger("databricks.sdk") +class HostType(Enum): + """Represents the type of API the configured host supports.""" + WORKSPACE_HOST = "WORKSPACE_HOST" # Supports only workspace-level APIs + ACCOUNT_HOST = "ACCOUNT_HOST" # Supports only account-level APIs + UNIFIED_HOST = "UNIFIED_HOST" # Supports both workspace-level and account-level APIs + + +class ConfigType(Enum): + """Represents the type of API this config is valid for.""" + WORKSPACE_CONFIG = "WORKSPACE_CONFIG" # Valid for workspace-level API requests + ACCOUNT_CONFIG = "ACCOUNT_CONFIG" # Valid for account-level API requests + INVALID_CONFIG = "INVALID_CONFIG" # Not valid for 
either workspace-level or account-level APIs + + class ConfigAttribute: """Configuration attribute metadata and descriptor protocols.""" @@ -62,6 +77,9 @@ class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") + # Databricks Workspace ID for Workspace clients when working with unified hosts + workspace_id: str = ConfigAttribute(env="DATABRICKS_WORKSPACE_ID") + # PAT token. token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) @@ -108,6 +126,9 @@ class Config: max_connections_per_pool: int = ConfigAttribute() databricks_environment: Optional[DatabricksEnvironment] = None + # Marker for unified hosts. Will be redundant once we can recognize unified hosts by their hostname. + experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST") + disable_async_token_refresh: bool = ConfigAttribute(env="DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH") disable_experimental_files_api_client: bool = ConfigAttribute( @@ -288,7 +309,13 @@ def parse_dsn(dsn: str) -> "Config": def authenticate(self) -> Dict[str, str]: """Returns a list of fresh authentication headers""" - return self._header_factory() + headers = self._header_factory() + # Unified hosts use X-Databricks-Org-Id header to determine which workspace to route the request to. + # The header must not be set for account-level API requests, otherwise the request will fail. + # This relies on the assumption that workspace_id is only set for workspace client configs. + if self.host_type() == HostType.UNIFIED_HOST and self.workspace_id: + headers["X-Databricks-Org-Id"] = self.workspace_id + return headers def as_dict(self) -> dict: return self._inner @@ -337,10 +364,59 @@ def is_aws(self) -> bool: @property def is_account_client(self) -> bool: + """Returns true if client is configured for Accounts API. + + Deprecated: Use host_type() if possible, or config_type() if necessary. 
+ Raises RuntimeError if the config has the unified host flag set. + """ + if self.experimental_is_unified_host: + raise RuntimeError("is_account_client cannot be used with unified hosts; use host_type() instead") if not self.host: return False return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.") + def host_type(self) -> HostType: + """Returns the type of host that the client is configured for.""" + if self.experimental_is_unified_host: + return HostType.UNIFIED_HOST + + if not self.host: + return HostType.WORKSPACE_HOST + + accounts_prefixes = [ + "https://accounts.", + "https://accounts-dod.", + ] + for prefix in accounts_prefixes: + if self.host.startswith(prefix): + return HostType.ACCOUNT_HOST + + return HostType.WORKSPACE_HOST + + def config_type(self) -> ConfigType: + """Returns the type of config that the client is configured for. + + Returns ConfigType.INVALID_CONFIG if the config is invalid. + Use of this function should be avoided where possible, because we plan + to remove WorkspaceClient and AccountClient in favor of a single unified + client in the future. 
+ """ + host_type = self.host_type() + + if host_type == HostType.ACCOUNT_HOST: + return ConfigType.ACCOUNT_CONFIG + elif host_type == HostType.WORKSPACE_HOST: + return ConfigType.WORKSPACE_CONFIG + elif host_type == HostType.UNIFIED_HOST: + if not self.account_id: + # All unified host configs must have an account ID + return ConfigType.INVALID_CONFIG + if self.workspace_id: + return ConfigType.WORKSPACE_CONFIG + return ConfigType.ACCOUNT_CONFIG + else: + return ConfigType.INVALID_CONFIG + @property def arm_environment(self) -> AzureEnvironment: return self.environment.azure_environment @@ -391,9 +467,15 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]: return None if self.is_azure and self.azure_client_id: return get_azure_entra_id_workspace_endpoints(self.host) - if self.is_account_client and self.account_id: + + host_type = self.host_type() + if host_type == HostType.ACCOUNT_HOST and self.account_id: return get_account_endpoints(self.host, self.account_id) - return get_workspace_endpoints(self.host) + elif host_type == HostType.UNIFIED_HOST and self.account_id: + return get_unified_endpoints(self.host, self.account_id) + elif host_type == HostType.WORKSPACE_HOST: + return get_workspace_endpoints(self.host) + return None def debug_string(self) -> str: """Returns log-friendly representation of configured attributes""" diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 022482370..d6926c37f 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -382,10 +382,11 @@ def _oidc_credentials_provider( return None # Determine the audience for token exchange + from .config import ConfigType audience = cfg.token_audience - if audience is None and cfg.is_account_client: + if audience is None and cfg.config_type() != ConfigType.WORKSPACE_CONFIG: audience = cfg.account_id - if audience is None and not cfg.is_account_client: + if audience is None and cfg.config_type() == 
ConfigType.WORKSPACE_CONFIG: audience = cfg.oidc_endpoints.token_endpoint # Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode. @@ -537,9 +538,10 @@ def token() -> oauth.Token: return credentials.token def refreshed_headers() -> Dict[str, str]: + from .config import ConfigType credentials.refresh(request) headers = {"Authorization": f"Bearer {credentials.token}"} - if cfg.is_account_client: + if cfg.config_type() != ConfigType.WORKSPACE_CONFIG: gcp_credentials.refresh(request) headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token return headers @@ -578,9 +580,10 @@ def token() -> oauth.Token: return id_creds.token def refreshed_headers() -> Dict[str, str]: + from .config import ConfigType id_creds.refresh(request) headers = {"Authorization": f"Bearer {id_creds.token}"} - if cfg.is_account_client: + if cfg.config_type() != ConfigType.WORKSPACE_CONFIG: gcp_impersonated_credentials.refresh(request) headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token return headers @@ -801,8 +804,9 @@ class DatabricksCliTokenSource(CliTokenSource): """Obtain the token granted by `databricks auth login` CLI command""" def __init__(self, cfg: "Config"): + from .config import ConfigType args = ["auth", "token", "--host", cfg.host] - if cfg.is_account_client: + if cfg.config_type() != ConfigType.WORKSPACE_CONFIG: args += ["--account-id", cfg.account_id] cli_path = cfg.databricks_cli_path diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index f18f0cd51..dbf8e86c5 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -382,6 +382,19 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas return OidcEndpoints.from_dict(resp) +def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a unified host. + :param host: The Databricks unified host. 
+ :param account_id: The account ID. + :return: The unified host's OIDC endpoints. + """ + host = _fix_host_if_needed(host) + oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server" + resp = client.do("GET", oidc) + return OidcEndpoints.from_dict(resp) + + def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: """ Get the OIDC endpoints for a given workspace. diff --git a/tests/test_config.py b/tests/test_config.py index 59fbf8712..f65fabcea 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,8 +7,8 @@ import pytest -from databricks.sdk import oauth, useragent -from databricks.sdk.config import Config, with_product, with_user_agent_extra +from databricks.sdk import AccountClient, WorkspaceClient, oauth, useragent +from databricks.sdk.config import Config, ConfigType, HostType, with_product, with_user_agent_extra from databricks.sdk.version import __version__ from .conftest import noop_credentials, set_az_path @@ -260,3 +260,268 @@ def test_oauth_token_reuses_existing_provider(mocker): # Both calls should work and use the same provider instance assert token1 == token2 == mock_token assert mock_oauth_provider.oauth_token.call_count == 2 + + +def test_host_type_aws_account(): + """Test that host_type returns ACCOUNT_HOST for AWS accounts host.""" + config = Config( + host="https://accounts.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert config.host_type() == HostType.ACCOUNT_HOST + + +def test_host_type_aws_dod_account(): + """Test that host_type returns ACCOUNT_HOST for AWS DoD accounts host.""" + config = Config( + host="https://accounts-dod.cloud.databricks.us", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert config.host_type() == HostType.ACCOUNT_HOST + + +def test_host_type_aws_workspace(): + """Test that host_type returns WORKSPACE_HOST for AWS 
workspace host.""" + config = Config( + host="https://my-workspace.cloud.databricks.us", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert config.host_type() == HostType.WORKSPACE_HOST + + +def test_host_type_unified(): + """Test that host_type returns UNIFIED_HOST for unified host.""" + config = Config( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + assert config.host_type() == HostType.UNIFIED_HOST + + +def test_is_account_client_raises_on_unified_host(): + """Test that is_account_client raises RuntimeError on unified host.""" + config = Config( + host="https://unified.cloud.databricks.com", + account_id="test-account", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + with pytest.raises(RuntimeError) as exc_info: + _ = config.is_account_client + assert "is_account_client cannot be used with unified hosts" in str(exc_info.value) + + +def test_config_type_account(): + """Test that config_type returns ACCOUNT_CONFIG for account host.""" + config = Config( + host="https://accounts.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert config.config_type() == ConfigType.ACCOUNT_CONFIG + + +def test_config_type_workspace(): + """Test that config_type returns WORKSPACE_CONFIG for workspace host.""" + config = Config( + host="https://my-workspace.cloud.databricks.us", + credentials_strategy=noop_credentials, + ) + assert config.config_type() == ConfigType.WORKSPACE_CONFIG + + +def test_config_type_unified_with_workspace_id(): + """Test that config_type returns WORKSPACE_CONFIG for unified host with workspace_id.""" + config = Config( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + workspace_id="12345", + 
experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + assert config.config_type() == ConfigType.WORKSPACE_CONFIG + + +def test_config_type_unified_without_workspace_id(): + """Test that config_type returns ACCOUNT_CONFIG for unified host without workspace_id.""" + config = Config( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + assert config.config_type() == ConfigType.ACCOUNT_CONFIG + + +def test_config_type_unified_invalid_without_account_id(): + """Test that config_type returns INVALID_CONFIG for unified host without account_id.""" + config = Config( + host="https://unified.cloud.databricks.com", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + assert config.config_type() == ConfigType.INVALID_CONFIG + + +def test_oidc_endpoints_unified(requests_mock): + """Test OIDC endpoints for unified host.""" + mock = requests_mock.get( + "https://unified.cloud.databricks.com/oidc/accounts/abc/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/authorize", + "token_endpoint": "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/token", + }, + ) + config = Config( + host="https://unified.cloud.databricks.com", + account_id="abc", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + endpoints = config.oidc_endpoints + assert endpoints is not None + assert endpoints.authorization_endpoint == "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/authorize" + assert endpoints.token_endpoint == "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/token" + assert mock.called_once + + +def test_authenticate_adds_org_id_header_for_unified_workspace(): + """Test that authenticate() adds X-Databricks-Org-Id header for unified workspace config.""" + config = 
Config( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + workspace_id="12345", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + headers = config.authenticate() + assert "X-Databricks-Org-Id" in headers + assert headers["X-Databricks-Org-Id"] == "12345" + + +def test_authenticate_no_org_id_header_for_unified_account(): + """Test that authenticate() does not add X-Databricks-Org-Id header for unified account config.""" + config = Config( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + headers = config.authenticate() + assert "X-Databricks-Org-Id" not in headers + + +def test_authenticate_no_org_id_header_for_workspace_host(): + """Test that authenticate() does not add X-Databricks-Org-Id header for non-unified workspace host.""" + config = Config( + host="https://my-workspace.cloud.databricks.us", + credentials_strategy=noop_credentials, + ) + headers = config.authenticate() + assert "X-Databricks-Org-Id" not in headers + + +def test_workspace_client_rejects_account_host(): + """Test that WorkspaceClient raises ValueError for account host.""" + with pytest.raises(ValueError) as exc_info: + WorkspaceClient( + host="https://accounts.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert "invalid Databricks Workspace configuration - host is not a workspace host" in str(exc_info.value) + + +def test_workspace_client_rejects_unified_host_without_workspace_id(): + """Test that WorkspaceClient raises ValueError for unified host without workspace_id.""" + with pytest.raises(ValueError) as exc_info: + WorkspaceClient( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + experimental_is_unified_host=True, + 
credentials_strategy=noop_credentials, + ) + assert "workspace_id must be set when using WorkspaceClient with unified host" in str(exc_info.value) + + +def test_workspace_client_accepts_unified_host_with_workspace_id(): + """Test that WorkspaceClient accepts unified host with workspace_id.""" + client = WorkspaceClient( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + workspace_id="12345", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + assert client is not None + assert client.config.workspace_id == "12345" + + +def test_workspace_client_accepts_workspace_host(): + """Test that WorkspaceClient accepts workspace host.""" + client = WorkspaceClient( + host="https://my-workspace.cloud.databricks.us", + credentials_strategy=noop_credentials, + ) + assert client is not None + + +def test_account_client_rejects_workspace_host(): + """Test that AccountClient raises ValueError for workspace host.""" + with pytest.raises(ValueError) as exc_info: + AccountClient( + host="https://my-workspace.cloud.databricks.us", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert "invalid Databricks Account configuration - host incorrect or account_id missing" in str(exc_info.value) + + +def test_account_client_rejects_missing_account_id(): + """Test that AccountClient raises ValueError when account_id is missing.""" + with pytest.raises(ValueError) as exc_info: + AccountClient( + host="https://accounts.cloud.databricks.com", + credentials_strategy=noop_credentials, + ) + assert "invalid Databricks Account configuration - host incorrect or account_id missing" in str(exc_info.value) + + +def test_account_client_rejects_workspace_id(): + """Test that AccountClient raises ValueError when workspace_id is set.""" + with pytest.raises(ValueError) as exc_info: + AccountClient( + host="https://accounts.cloud.databricks.com", + 
account_id="123e4567-e89b-12d3-a456-426614174000", + workspace_id="12345", + credentials_strategy=noop_credentials, + ) + assert "workspace_id must not be set when using AccountClient" in str(exc_info.value) + + +def test_account_client_accepts_account_host(): + """Test that AccountClient accepts account host.""" + client = AccountClient( + host="https://accounts.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + credentials_strategy=noop_credentials, + ) + assert client is not None + assert client.config.account_id == "123e4567-e89b-12d3-a456-426614174000" + + +def test_account_client_accepts_unified_host(): + """Test that AccountClient accepts unified host without workspace_id.""" + client = AccountClient( + host="https://unified.cloud.databricks.com", + account_id="123e4567-e89b-12d3-a456-426614174000", + experimental_is_unified_host=True, + credentials_strategy=noop_credentials, + ) + assert client is not None + assert client.config.account_id == "123e4567-e89b-12d3-a456-426614174000"