diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py
index dcf680a8c..e4485a10a 100644
--- a/databricks/sdk/__init__.py
+++ b/databricks/sdk/__init__.py
@@ -253,7 +253,7 @@ def __init__(
             product=product,
             product_version=product_version,
             token_audience=token_audience,
-            scopes=" ".join(scopes) if scopes else None,
+            scopes=scopes,
             authorization_details=(
                 json.dumps([detail.as_dict() for detail in authorization_details])
                 if authorization_details
diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
index fa61bdbb9..8642d0f6d 100644
--- a/databricks/sdk/config.py
+++ b/databricks/sdk/config.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import pathlib
+import re
 import sys
 import urllib.parse
 from enum import Enum
@@ -46,11 +47,13 @@ class ConfigAttribute:
     # name and transform are discovered from Config.__new__
     name: str = None
     transform: type = str
+    _custom_transform = None
 
-    def __init__(self, env: str = None, auth: str = None, sensitive: bool = False):
+    def __init__(self, env: str = None, auth: str = None, sensitive: bool = False, transform=None):
         self.env = env
         self.auth = auth
         self.sensitive = sensitive
+        self._custom_transform = transform
 
     def __get__(self, cfg: "Config", owner):
         if not cfg:
@@ -64,6 +67,19 @@ def __repr__(self) -> str:
         return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"
 
 
+def _parse_scopes(value):
+    """Parse scopes into a deduplicated, sorted list."""
+    if value is None:
+        return None
+    if isinstance(value, list):
+        result = sorted(set(s for s in value if s))
+        return result if result else None
+    if isinstance(value, str):
+        parsed: list = sorted(set(s for s in re.split(r"[, ]+", value) if s))
+        return parsed if parsed else None
+    return None
+
+
 def with_product(product: str, product_version: str):
     """[INTERNAL API] Change the product name and version used in the User-Agent header."""
     useragent.with_product(product, product_version)
@@ -133,10 +149,14 @@ class Config:
     disable_experimental_files_api_client: bool = ConfigAttribute(
         env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT"
     )
-    # TODO: Expose these via environment variables too.
-    scopes: str = ConfigAttribute()
+
+    scopes: list = ConfigAttribute(transform=_parse_scopes)
     authorization_details: str = ConfigAttribute()
 
+    # disable_oauth_refresh_token controls whether a refresh token should be requested
+    # during the U2M authentication flow (defaults to false).
+    disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN")
+
     files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024  # 2 MiB
 
     # When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit.
@@ -553,7 +573,7 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
             if type(v) != ConfigAttribute:
                 continue
             v.name = name
-            v.transform = anno.get(name, str)
+            v.transform = v._custom_transform if v._custom_transform else anno.get(name, str)
             attrs.append(v)
         cls._attributes = attrs
         return cls._attributes
@@ -685,6 +705,21 @@ def _init_product(self, product, product_version):
         else:
             self._product_info = None
 
+    def get_scopes(self) -> list:
+        """Get OAuth scopes with proper defaulting.
+
+        Returns ["all-apis"] if no scopes are configured.
+        This is the single source of truth for scope defaulting across all OAuth methods.
+        """
+        return self.scopes if self.scopes else ["all-apis"]
+
+    def get_scopes_as_string(self) -> str:
+        """Get OAuth scopes as a space-separated string.
+
+        Returns "all-apis" if no scopes are configured.
+        """
+        return " ".join(self.get_scopes())
+
     def __repr__(self):
         return f"<{self.debug_string()}>"
 
diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py
index 926c50a05..ebf6fa2bd 100644
--- a/databricks/sdk/credentials_provider.py
+++ b/databricks/sdk/credentials_provider.py
@@ -198,7 +198,7 @@ def get_notebook_pat_token() -> Optional[str]:
     token_source = oauth.PATOAuthTokenExchange(
         get_original_token=get_notebook_pat_token,
         host=cfg.host,
-        scopes=cfg.scopes,
+        scopes=cfg.get_scopes_as_string(),
         authorization_details=cfg.authorization_details,
     )
 
@@ -329,7 +329,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
             endpoint_params={"resource": resource},
             use_params=True,
             disable_async=cfg.disable_async_token_refresh,
-            scopes=cfg.scopes,
+            scopes=cfg.get_scopes_as_string(),
             authorization_details=cfg.authorization_details,
         )
 
@@ -533,7 +533,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
         },
         use_params=True,
         disable_async=cfg.disable_async_token_refresh,
-        scopes=cfg.scopes,
+        scopes=cfg.get_scopes_as_string(),
         authorization_details=cfg.authorization_details,
     )
 
diff --git a/tests/test_config.py b/tests/test_config.py
index 00e7540d9..1622cb330 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -12,7 +12,7 @@
                                     with_user_agent_extra)
 from databricks.sdk.version import __version__
 
-from .conftest import noop_credentials, set_az_path
+from .conftest import noop_credentials, set_az_path, set_home
 
 __tests__ = os.path.dirname(__file__)
 
@@ -453,3 +453,80 @@ def test_no_org_id_header_on_regular_workspace(requests_mock):
 
     # Verify the X-Databricks-Org-Id header was NOT added
     assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers
+
+
+def test_disable_oauth_refresh_token_from_env(monkeypatch, mocker):
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    monkeypatch.setenv("DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN", "true")
+    config = Config(host="https://test.databricks.com")
+    assert config.disable_oauth_refresh_token is True
+
+
+def test_disable_oauth_refresh_token_defaults_to_false(mocker):
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    config = Config(host="https://test.databricks.com")
+    assert not config.disable_oauth_refresh_token
+
+
+@pytest.mark.parametrize(
+    "profile,expected_scopes",
+    [
+        ("scope-empty", ["all-apis"]),
+        ("scope-single", ["clusters"]),
+        ("scope-multiple", ["clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving:read", "pipelines"]),
+    ],
+    ids=["empty_defaults_to_all_apis", "single_scope", "multiple_sorted"],
+)
+def test_config_file_scopes(monkeypatch, mocker, profile, expected_scopes):
+    """Test scopes from config file profiles."""
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    set_home(monkeypatch, "/testdata")
+    config = Config(profile=profile)
+    assert config.get_scopes() == expected_scopes
+
+
+@pytest.mark.parametrize(
+    "scopes_input,expected_scopes",
+    [
+        # List input
+        (["jobs", "clusters", "mlflow"], ["clusters", "jobs", "mlflow"]),
+        # Deduplication (list)
+        (["clusters", "jobs", "clusters", "jobs", "mlflow"], ["clusters", "jobs", "mlflow"]),
+        # Deduplication (string)
+        ("clusters,jobs,clusters,jobs,mlflow", ["clusters", "jobs", "mlflow"]),
+        # Space-separated (backwards compatibility)
+        ("clusters jobs mlflow", ["clusters", "jobs", "mlflow"]),
+        # Mixed separators
+        ("clusters, jobs mlflow,pipelines", ["clusters", "jobs", "mlflow", "pipelines"]),
+        # Empty string defaults to all-apis
+        ("", ["all-apis"]),
+        # Whitespace-only defaults to all-apis
+        (" ", ["all-apis"]),
+        # None defaults to all-apis
+        (None, ["all-apis"]),
+        # Empty list defaults to all-apis
+        ([], ["all-apis"]),
+        # Empty strings in list are filtered
+        (["clusters", "", "jobs", ""], ["clusters", "jobs"]),
+        # List with only empty strings defaults to all-apis
+        (["", "", ""], ["all-apis"]),
+    ],
+    ids=[
+        "list_input",
+        "deduplication_list",
+        "deduplication_string",
+        "space_separated",
+        "mixed_separators",
+        "empty_string",
+        "whitespace_only",
+        "none",
+        "empty_list",
+        "list_with_empty_strings",
+        "list_only_empty_strings",
+    ],
+)
+def test_scopes_parsing(mocker, scopes_input, expected_scopes):
+    """Test scopes parsing with various input formats."""
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    config = Config(host="https://test.databricks.com", scopes=scopes_input)
+    assert config.get_scopes() == expected_scopes
diff --git a/tests/test_notebook_oauth.py b/tests/test_notebook_oauth.py
index 55e5237d6..bde8f2165 100644
--- a/tests/test_notebook_oauth.py
+++ b/tests/test_notebook_oauth.py
@@ -174,7 +174,7 @@ def test_config_authenticate_integration(
 
 @pytest.mark.parametrize(
     "scopes_input,expected_scopes",
-    [(["sql", "offline_access"], "sql offline_access")],
+    [(["sql", "offline_access"], ["offline_access", "sql"])],
 )
 def test_workspace_client_integration(
     mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes
diff --git a/tests/testdata/.databrickscfg b/tests/testdata/.databrickscfg
index 2759b6c1b..2ffb627ae 100644
--- a/tests/testdata/.databrickscfg
+++ b/tests/testdata/.databrickscfg
@@ -38,4 +38,15 @@ google_credentials = paw48590aw8e09t8apu
 
 [pat.with.dot]
 host = https://dbc-XXXXXXXX-YYYY.cloud.databricks.com/
-token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ
\ No newline at end of file
+token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ
+
+[scope-empty]
+host = https://example.cloud.databricks.com
+
+[scope-single]
+host = https://example.cloud.databricks.com
+scopes = clusters
+
+[scope-multiple]
+host = https://example.cloud.databricks.com
+scopes = clusters, jobs, pipelines, iam:read, files:read, mlflow, model-serving:read