diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index dcf680a8c..e4485a10a 100644 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -253,7 +253,7 @@ def __init__( product=product, product_version=product_version, token_audience=token_audience, - scopes=" ".join(scopes) if scopes else None, + scopes=scopes, authorization_details=( json.dumps([detail.as_dict() for detail in authorization_details]) if authorization_details diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index fa61bdbb9..bfc241855 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -46,11 +46,13 @@ class ConfigAttribute: # name and transform are discovered from Config.__new__ name: str = None transform: type = str + _custom_transform = None - def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): + def __init__(self, env: str = None, auth: str = None, sensitive: bool = False, transform=None): self.env = env self.auth = auth self.sensitive = sensitive + self._custom_transform = transform def __get__(self, cfg: "Config", owner): if not cfg: @@ -64,6 +66,19 @@ def __repr__(self) -> str: return f"" +def _parse_scopes(value): + """Parse scopes (a str, or a list/tuple/set of str) into a deduplicated, sorted list; None if empty.""" + if value is None: + return None + if isinstance(value, str): + value = [value] + if not isinstance(value, (list, tuple, set, frozenset)): + return None + # Scope tokens never contain whitespace (RFC 6749, section 3.3): split on commas and whitespace alike. + tokens = {t for item in value if isinstance(item, str) for t in item.replace(",", " ").split()} + return sorted(tokens) or None + + def with_product(product: str, product_version: str): """[INTERNAL API] Change the product name and version used in the User-Agent header.""" useragent.with_product(product, product_version) @@ -133,10 +148,14 @@ class Config: disable_experimental_files_api_client: bool = ConfigAttribute( env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT" ) - # TODO: Expose these via environment variables too. 
- scopes: str = ConfigAttribute() + + scopes: list = ConfigAttribute(transform=_parse_scopes) authorization_details: str = ConfigAttribute() + # disable_oauth_refresh_token controls whether a refresh token should be requested + # during the U2M authentication flow. Unset (None) is treated as false. + disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN") + files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB # When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit. @@ -553,7 +572,7 @@ def attributes(cls) -> Iterable[ConfigAttribute]: if type(v) != ConfigAttribute: continue v.name = name - v.transform = anno.get(name, str) + v.transform = v._custom_transform if v._custom_transform else anno.get(name, str) attrs.append(v) cls._attributes = attrs return cls._attributes @@ -685,6 +704,21 @@ def _init_product(self, product, product_version): else: self._product_info = None + def get_scopes(self) -> list: + """Get OAuth scopes with proper defaulting. + + Returns ["all-apis"] if no scopes are configured. + This is the single source of truth for scope defaulting across all OAuth methods. + """ + return self.scopes if self.scopes else ["all-apis"] + + def get_scopes_as_string(self) -> str: + """Get OAuth scopes as a space-separated string. + + Returns "all-apis" if no scopes are configured. 
+ """ + return " ".join(self.get_scopes()) + def __repr__(self): return f"<{self.debug_string()}>" diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 926c50a05..b3e324bcc 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -198,7 +198,7 @@ def get_notebook_pat_token() -> Optional[str]: token_source = oauth.PATOAuthTokenExchange( get_original_token=get_notebook_pat_token, host=cfg.host, - scopes=cfg.scopes, + scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, ) @@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: client_id=cfg.client_id, client_secret=cfg.client_secret, token_url=oidc.token_endpoint, - scopes=cfg.scopes or "all-apis", + scopes=cfg.get_scopes_as_string(), use_header=True, disable_async=cfg.disable_async_token_refresh, authorization_details=cfg.authorization_details, @@ -256,6 +256,11 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: if not client_id: client_id = "databricks-cli" + scopes = cfg.get_scopes() + if not cfg.disable_oauth_refresh_token: + if "offline_access" not in scopes: + scopes = scopes + ["offline_access"] + # Load cached credentials from disk if they exist. Note that these are # local to the Python SDK and not reused by other SDKs. 
oidc_endpoints = cfg.oidc_endpoints @@ -266,6 +271,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: client_id=client_id, client_secret=client_secret, redirect_url=redirect_url, + scopes=scopes, ) credentials = token_cache.load() if credentials: @@ -284,6 +290,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: client_id=client_id, redirect_url=redirect_url, client_secret=client_secret, + scopes=scopes, ) consent = oauth_client.initiate_consent() if not consent: @@ -329,7 +336,7 @@ def token_source_for(resource: str) -> oauth.TokenSource: endpoint_params={"resource": resource}, use_params=True, disable_async=cfg.disable_async_token_refresh, - scopes=cfg.scopes, + scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, ) @@ -387,6 +394,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio account_id=cfg.account_id, id_token_source=id_token_source, disable_async=cfg.disable_async_token_refresh, + scopes=cfg.get_scopes_as_string(), ) def refreshed_headers() -> Dict[str, str]: @@ -450,7 +458,7 @@ def token_source_for(audience: str) -> oauth.TokenSource: "subject_token": id_token, "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", }, - scopes=cfg.scopes or "all-apis", + scopes=cfg.get_scopes_as_string(), use_params=True, disable_async=cfg.disable_async_token_refresh, authorization_details=cfg.authorization_details, @@ -533,7 +541,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: }, use_params=True, disable_async=cfg.disable_async_token_refresh, - scopes=cfg.scopes, + scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, ) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 3e6a97abb..d3dac425a 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -661,10 +661,10 @@ def __init__( client_secret: str = None, ): if not scopes: - # all-apis ensures that the returned 
OAuth token can be used with all APIs, aside - # from direct-to-dataplane APIs. - # offline_access ensures that the response from the Authorization server includes - # a refresh token. + # Default for direct OAuthClient users (e.g., via from_host()). + # When used via credentials_provider.external_browser(), scopes are always + # passed explicitly from Config.get_scopes(), with offline_access handling + # controlled by the disable_oauth_refresh_token flag. scopes = ["all-apis", "offline_access"] self.redirect_url = redirect_url diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index b8641a45d..6fd273f2a 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -156,6 +156,7 @@ def __init__( account_id: Optional[str] = None, audience: Optional[str] = None, disable_async: bool = False, + scopes: Optional[str] = None, ): self._host = host self._id_token_source = id_token_source @@ -164,6 +165,7 @@ def __init__( self._account_id = account_id self._audience = audience self._disable_async = disable_async + self._scopes = scopes def token(self) -> oauth.Token: """Get a token by exchanging the ID token. 
@@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token: "subject_token": id_token.jwt, "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", }, - scopes="all-apis", + scopes=self._scopes or "all-apis",  # keep the previous default when no scopes were passed use_params=True, disable_async=self._disable_async, ) diff --git a/tests/test_config.py b/tests/test_config.py index 00e7540d9..b860cd10c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,8 @@ import random import string from datetime import datetime +from typing import Optional +from urllib.parse import parse_qs import pytest @@ -12,7 +14,7 @@ with_user_agent_extra) from databricks.sdk.version import __version__ -from .conftest import noop_credentials, set_az_path +from .conftest import noop_credentials, set_az_path, set_home __tests__ = os.path.dirname(__file__) @@ -453,3 +455,121 @@ def test_no_org_id_header_on_regular_workspace(requests_mock): # Verify the X-Databricks-Org-Id header was NOT added assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers + + +def test_disable_oauth_refresh_token_from_env(monkeypatch, mocker): + mocker.patch("databricks.sdk.config.Config.init_auth") + monkeypatch.setenv("DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN", "true") + config = Config(host="https://test.databricks.com") + assert config.disable_oauth_refresh_token is True + + +def test_disable_oauth_refresh_token_defaults_to_false(mocker): + mocker.patch("databricks.sdk.config.Config.init_auth") + config = Config(host="https://test.databricks.com") + assert config.disable_oauth_refresh_token is None # ConfigAttribute returns None when not set + + +def test_config_file_scopes_empty_defaults_to_all_apis(monkeypatch, mocker): + """Test that empty scopes in config file defaults to all-apis.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + set_home(monkeypatch, "/testdata") + config = Config(profile="scope-empty") + assert config.get_scopes() == ["all-apis"] + + +def test_config_file_scopes_single(monkeypatch, 
mocker): + """Test single scope from config file.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + set_home(monkeypatch, "/testdata") + config = Config(profile="scope-single") + assert config.get_scopes() == ["clusters"] + + +def test_config_file_scopes_multiple_sorted(monkeypatch, mocker): + """Test multiple scopes from config file are sorted.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + set_home(monkeypatch, "/testdata") + config = Config(profile="scope-multiple") + # Should be sorted alphabetically + expected = ["clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving:read", "pipelines"] + assert config.get_scopes() == expected + + +def _get_scope_from_request(request_text: str) -> Optional[str]: + """Extract the scope value from a URL-encoded request body.""" + params = parse_qs(request_text) + scope_list = params.get("scope") + return scope_list[0] if scope_list else None + + +@pytest.mark.parametrize( + "scopes_input,expected_scope", + [ + (None, "all-apis"), + (["unity-catalog:read"], "unity-catalog:read"), + (["jobs:read", "clusters", "mlflow:read"], "clusters jobs:read mlflow:read"), + ], + ids=["default_scope", "single_custom_scope", "multiple_scopes_sorted"], +) +def test_m2m_scopes_sent_to_token_endpoint(requests_mock, scopes_input, expected_scope): + """Test M2M authentication sends correct scopes to token endpoint.""" + requests_mock.get( + "https://test.databricks.com/oidc/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize", + "token_endpoint": "https://test.databricks.com/oidc/v1/token", + }, + ) + token_mock = requests_mock.post( + "https://test.databricks.com/oidc/v1/token", + json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600}, + ) + + config = Config( + host="https://test.databricks.com", + client_id="test-client-id", + client_secret="test-client-secret", + auth_type="oauth-m2m", + scopes=scopes_input, 
+ ) + config.authenticate() + + assert _get_scope_from_request(token_mock.last_request.text) == expected_scope + + +@pytest.mark.parametrize( + "scopes_input,expected_scope", + [ + (None, "all-apis"), + (["unity-catalog:read", "clusters"], "clusters unity-catalog:read"), + (["jobs:read"], "jobs:read"), + ], + ids=["default_scope", "multiple_scopes", "single_scope"], +) +def test_oidc_scopes_sent_to_token_endpoint(requests_mock, tmp_path, scopes_input, expected_scope): + """Test OIDC token exchange sends correct scopes to token endpoint.""" + oidc_token_file = tmp_path / "oidc_token" + oidc_token_file.write_text("mock-id-token") + + requests_mock.get( + "https://test.databricks.com/oidc/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize", + "token_endpoint": "https://test.databricks.com/oidc/v1/token", + }, + ) + token_mock = requests_mock.post( + "https://test.databricks.com/oidc/v1/token", + json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600}, + ) + + config = Config( + host="https://test.databricks.com", + oidc_token_filepath=str(oidc_token_file), + auth_type="file-oidc", + scopes=scopes_input, + ) + config.authenticate() + + assert _get_scope_from_request(token_mock.last_request.text) == expected_scope diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index 51dc78eb3..ab9c9c787 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -1,7 +1,10 @@ from datetime import datetime, timedelta from unittest.mock import Mock +import pytest + from databricks.sdk import credentials_provider, oauth, oidc +from databricks.sdk.config import Config # Tests for external_browser function @@ -174,6 +177,53 @@ def test_external_browser_consent_fails(mocker): assert got_credentials_provider is None +def _setup_external_browser_mocks(mocker, cfg): + """Set up mocks for external_browser scope tests. 
Returns (TokenCache mock, OAuthClient mock).""" + mock_oidc_endpoints = Mock() + mock_oidc_endpoints.token_endpoint = "https://test.databricks.com/oidc/v1/token" + mocker.patch.object(type(cfg), "oidc_endpoints", new_callable=lambda: property(lambda self: mock_oidc_endpoints)) + + mock_token_cache_class = mocker.patch("databricks.sdk.credentials_provider.oauth.TokenCache") + mock_token_cache = Mock() + mock_token_cache.load.return_value = None + mock_token_cache_class.return_value = mock_token_cache + + mock_oauth_client_class = mocker.patch("databricks.sdk.credentials_provider.oauth.OAuthClient") + mock_oauth_client = Mock() + mock_consent = Mock() + mock_consent.launch_external_browser.return_value = Mock() + mock_oauth_client.initiate_consent.return_value = mock_consent + mock_oauth_client_class.return_value = mock_oauth_client + + return mock_token_cache_class, mock_oauth_client_class + + +@pytest.mark.parametrize( + "scopes,disable_refresh,expected_scopes", + [ + (None, False, ["all-apis", "offline_access"]), + ("sql, clusters, jobs", False, ["clusters", "jobs", "sql", "offline_access"]), + (None, True, ["all-apis"]), + ], + ids=["default_scopes", "multiple_scopes_sorted", "disable_offline_access"], +) +def test_external_browser_scopes(mocker, scopes, disable_refresh, expected_scopes): + """Tests that external_browser passes correct scopes to TokenCache and OAuthClient.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + cfg = Config( + host="https://test.databricks.com", + auth_type="external-browser", + scopes=scopes, + disable_oauth_refresh_token=disable_refresh if disable_refresh else None, + ) + mock_token_cache_class, mock_oauth_client_class = _setup_external_browser_mocks(mocker, cfg) + + credentials_provider.external_browser(cfg) + + assert mock_token_cache_class.call_args.kwargs["scopes"] == expected_scopes + assert mock_oauth_client_class.call_args.kwargs["scopes"] == expected_scopes + + def 
test_oidc_credentials_provider_invalid_id_token_source(): # Use a mock config object to avoid initializing the auth initialization. mock_cfg = Mock() diff --git a/tests/test_notebook_oauth.py b/tests/test_notebook_oauth.py index 55e5237d6..0338f7345 100644 --- a/tests/test_notebook_oauth.py +++ b/tests/test_notebook_oauth.py @@ -78,10 +78,10 @@ def credentials_provider() -> Dict[str, str]: @pytest.mark.parametrize( "scopes,auth_details", [ - ("sql offline_access", None), - ("sql offline_access", '{"type": "databricks_resource"}'), + ("sql, offline_access", None), + ("sql, offline_access", '{"type": "databricks_resource"}'), ("sql", None), - ("sql offline_access all-apis", None), + ("sql, offline_access, all-apis", None), ], ) def test_runtime_oauth_success_scenarios( @@ -117,7 +117,7 @@ def test_runtime_oauth_missing_scopes(mock_runtime_env, mock_runtime_native_auth def test_runtime_oauth_priority_over_native_auth(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange): """Test that runtime-oauth is prioritized over runtime-native-auth.""" - cfg = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access") + cfg = Config(host="https://test.cloud.databricks.com", scopes="sql, offline_access") default_creds = DefaultCredentials() creds_provider = default_creds(cfg) @@ -141,7 +141,7 @@ def test_fallback_to_native_auth_without_scopes(mock_runtime_env, mock_runtime_n def test_explicit_runtime_oauth_auth_type(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange): """Test that runtime-oauth is used when explicitly specified as auth_type.""" - cfg = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access", auth_type="runtime-oauth") + cfg = Config(host="https://test.cloud.databricks.com", scopes="sql, offline_access", auth_type="runtime-oauth") default_creds = DefaultCredentials() creds_provider = default_creds(cfg) @@ -164,7 +164,7 @@ def test_config_authenticate_integration( """Test Config.authenticate() integration 
with runtime-oauth and fallback.""" cfg_kwargs = {"host": "https://test.cloud.databricks.com"} if has_scopes: - cfg_kwargs["scopes"] = "sql offline_access" + cfg_kwargs["scopes"] = "sql, offline_access" cfg = Config(**cfg_kwargs) headers = cfg.authenticate() @@ -174,7 +174,7 @@ def test_config_authenticate_integration( @pytest.mark.parametrize( "scopes_input,expected_scopes", - [(["sql", "offline_access"], "sql offline_access")], + [(["sql", "offline_access"], ["offline_access", "sql"])], ) def test_workspace_client_integration( mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes diff --git a/tests/testdata/.databrickscfg b/tests/testdata/.databrickscfg index 2759b6c1b..2ffb627ae 100644 --- a/tests/testdata/.databrickscfg +++ b/tests/testdata/.databrickscfg @@ -38,4 +38,15 @@ google_credentials = paw48590aw8e09t8apu [pat.with.dot] host = https://dbc-XXXXXXXX-YYYY.cloud.databricks.com/ -token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ \ No newline at end of file +token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ + +[scope-empty] +host = https://example.cloud.databricks.com + +[scope-single] +host = https://example.cloud.databricks.com +scopes = clusters + +[scope-multiple] +host = https://example.cloud.databricks.com +scopes = clusters, jobs, pipelines, iam:read, files:read, mlflow, model-serving:read