Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 38 additions & 4 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ class ConfigAttribute:
# name and transform are discovered from Config.__new__
name: str = None
transform: type = str
_custom_transform = None

def __init__(self, env: str = None, auth: str = None, sensitive: bool = False):
def __init__(self, env: str = None, auth: str = None, sensitive: bool = False, transform=None):
self.env = env
self.auth = auth
self.sensitive = sensitive
self._custom_transform = transform

def __get__(self, cfg: "Config", owner):
if not cfg:
Expand All @@ -64,6 +66,19 @@ def __repr__(self) -> str:
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"


def _parse_scopes(value):
"""Parse scopes into a deduplicated, sorted list."""
if value is None:
return None
if isinstance(value, list):
result = sorted(set(s for s in value if s))
return result if result else None
if isinstance(value, str):
parsed = sorted(set(s.strip() for s in value.split(",") if s.strip()))
return parsed if parsed else None
return None


def with_product(product: str, product_version: str):
"""[INTERNAL API] Change the product name and version used in the User-Agent header."""
useragent.with_product(product, product_version)
Expand Down Expand Up @@ -133,10 +148,14 @@ class Config:
disable_experimental_files_api_client: bool = ConfigAttribute(
env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT"
)
# TODO: Expose these via environment variables too.
scopes: str = ConfigAttribute()

scopes: list = ConfigAttribute(transform=_parse_scopes)
authorization_details: str = ConfigAttribute()

# disable_oauth_refresh_token controls whether a refresh token should be requested
# during the U2M authentication flow (default to false).
disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN")

files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB

# When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit.
Expand Down Expand Up @@ -553,7 +572,7 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
if type(v) != ConfigAttribute:
continue
v.name = name
v.transform = anno.get(name, str)
v.transform = v._custom_transform if v._custom_transform else anno.get(name, str)
attrs.append(v)
cls._attributes = attrs
return cls._attributes
Expand Down Expand Up @@ -685,6 +704,21 @@ def _init_product(self, product, product_version):
else:
self._product_info = None

def get_scopes(self) -> list:
    """Get OAuth scopes with proper defaulting.

    Returns ["all-apis"] if no scopes configured.
    This is the single source of truth for scope defaulting across all OAuth methods.

    :return: a new list on every call, so callers may modify the result
        without changing the scopes stored on this object.
    """
    configured = self.scopes
    if not configured:
        return ["all-apis"]
    # Hand out a copy: the default branch already returns a fresh list, and
    # the configured branch should not expose the stored list to mutation.
    return list(configured)

def get_scopes_as_string(self) -> str:
    """Render the effective OAuth scopes as one space-separated string.

    The scopes come from ``get_scopes()``, so the result is "all-apis" when
    no scopes are configured.
    """
    effective_scopes = self.get_scopes()
    return " ".join(effective_scopes)

def __repr__(self):
return f"<{self.debug_string()}>"

Expand Down
18 changes: 13 additions & 5 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def get_notebook_pat_token() -> Optional[str]:
token_source = oauth.PATOAuthTokenExchange(
get_original_token=get_notebook_pat_token,
host=cfg.host,
scopes=cfg.scopes,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
)

Expand All @@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=cfg.client_id,
client_secret=cfg.client_secret,
token_url=oidc.token_endpoint,
scopes=cfg.scopes or "all-apis",
scopes=cfg.get_scopes_as_string(),
use_header=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
Expand Down Expand Up @@ -256,6 +256,11 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
if not client_id:
client_id = "databricks-cli"

scopes = cfg.get_scopes()
if not cfg.disable_oauth_refresh_token:
if "offline_access" not in scopes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a couple of tests to ensure that (1) offline_access is added, and (2) it is added only if it is not already present.

scopes = scopes + ["offline_access"]

# Load cached credentials from disk if they exist. Note that these are
# local to the Python SDK and not reused by other SDKs.
oidc_endpoints = cfg.oidc_endpoints
Expand All @@ -266,6 +271,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=client_id,
client_secret=client_secret,
redirect_url=redirect_url,
scopes=scopes,
)
credentials = token_cache.load()
if credentials:
Expand All @@ -284,6 +290,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=client_id,
redirect_url=redirect_url,
client_secret=client_secret,
scopes=scopes,
)
consent = oauth_client.initiate_consent()
if not consent:
Expand Down Expand Up @@ -329,7 +336,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
endpoint_params={"resource": resource},
use_params=True,
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.scopes,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
)

Expand Down Expand Up @@ -387,6 +394,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
account_id=cfg.account_id,
id_token_source=id_token_source,
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -450,7 +458,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
"subject_token": id_token,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes=cfg.scopes or "all-apis",
scopes=cfg.get_scopes_as_string(),
use_params=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
Expand Down Expand Up @@ -533,7 +541,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
},
use_params=True,
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.scopes,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
)

Expand Down
8 changes: 4 additions & 4 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,10 @@ def __init__(
client_secret: str = None,
):
if not scopes:
# all-apis ensures that the returned OAuth token can be used with all APIs, aside
# from direct-to-dataplane APIs.
# offline_access ensures that the response from the Authorization server includes
# a refresh token.
# Default for direct OAuthClient users (e.g., via from_host()).
# When used via credentials_provider.external_browser(), scopes are always
# passed explicitly from Config.get_scopes(), with offline_access handling
# controlled by the disable_oauth_refresh_token flag.
scopes = ["all-apis", "offline_access"]

self.redirect_url = redirect_url
Expand Down
4 changes: 3 additions & 1 deletion databricks/sdk/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
account_id: Optional[str] = None,
audience: Optional[str] = None,
disable_async: bool = False,
scopes: Optional[str] = None,
):
self._host = host
self._id_token_source = id_token_source
Expand All @@ -164,6 +165,7 @@ def __init__(
self._account_id = account_id
self._audience = audience
self._disable_async = disable_async
self._scopes = scopes

def token(self) -> oauth.Token:
"""Get a token by exchanging the ID token.
Expand Down Expand Up @@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
"subject_token": id_token.jwt,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes="all-apis",
scopes=self._scopes,
use_params=True,
disable_async=self._disable_async,
)
Expand Down
122 changes: 121 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import random
import string
from datetime import datetime
from typing import Optional
from urllib.parse import parse_qs

import pytest

Expand All @@ -12,7 +14,7 @@
with_user_agent_extra)
from databricks.sdk.version import __version__

from .conftest import noop_credentials, set_az_path
from .conftest import noop_credentials, set_az_path, set_home

__tests__ = os.path.dirname(__file__)

Expand Down Expand Up @@ -453,3 +455,121 @@ def test_no_org_id_header_on_regular_workspace(requests_mock):

# Verify the X-Databricks-Org-Id header was NOT added
assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers


def test_disable_oauth_refresh_token_from_env(monkeypatch, mocker):
    """DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN=true is read into the config as True."""
    monkeypatch.setenv("DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN", "true")
    # Config.init_auth is patched out; presumably so no real authentication runs.
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(host="https://test.databricks.com")

    assert cfg.disable_oauth_refresh_token is True


def test_disable_oauth_refresh_token_defaults_to_false(mocker):
    """Without any configuration the flag is unset, which reads as a falsy None."""
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(host="https://test.databricks.com")

    # ConfigAttribute returns None when not set
    assert cfg.disable_oauth_refresh_token is None


def test_config_file_scopes_empty_defaults_to_all_apis(monkeypatch, mocker):
    """The "scope-empty" profile falls back to the default all-apis scope."""
    set_home(monkeypatch, "/testdata")
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(profile="scope-empty")

    assert cfg.get_scopes() == ["all-apis"]


def test_config_file_scopes_single(monkeypatch, mocker):
    """The "scope-single" profile yields exactly one scope."""
    set_home(monkeypatch, "/testdata")
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(profile="scope-single")

    assert cfg.get_scopes() == ["clusters"]


def test_config_file_scopes_multiple_sorted(monkeypatch, mocker):
    """The "scope-multiple" profile yields its scopes in alphabetical order."""
    set_home(monkeypatch, "/testdata")
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(profile="scope-multiple")

    # Alphabetical order is part of the expectation, not just the membership.
    assert cfg.get_scopes() == [
        "clusters",
        "files:read",
        "iam:read",
        "jobs",
        "mlflow",
        "model-serving:read",
        "pipelines",
    ]


def _get_scope_from_request(request_text: str) -> Optional[str]:
"""Extract the scope value from a URL-encoded request body."""
params = parse_qs(request_text)
scope_list = params.get("scope")
return scope_list[0] if scope_list else None


@pytest.mark.parametrize(
    "scopes_input,expected_scope",
    [
        (None, "all-apis"),
        (["unity-catalog:read"], "unity-catalog:read"),
        (["jobs:read", "clusters", "mlflow:read"], "clusters jobs:read mlflow:read"),
    ],
    ids=["default_scope", "single_custom_scope", "multiple_scopes_sorted"],
)
def test_m2m_scopes_sent_to_token_endpoint(requests_mock, scopes_input, expected_scope):
    """The M2M flow posts the expected ``scope`` value to the token endpoint."""
    discovery_url = "https://test.databricks.com/oidc/.well-known/oauth-authorization-server"
    token_url = "https://test.databricks.com/oidc/v1/token"

    # Mocked OIDC discovery document pointing at the mocked token endpoint.
    discovery_document = {
        "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
        "token_endpoint": token_url,
    }
    requests_mock.get(discovery_url, json=discovery_document)

    token_response = {"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600}
    token_endpoint = requests_mock.post(token_url, json=token_response)

    cfg = Config(
        host="https://test.databricks.com",
        client_id="test-client-id",
        client_secret="test-client-secret",
        auth_type="oauth-m2m",
        scopes=scopes_input,
    )
    cfg.authenticate()

    # Inspect the last request recorded by the mocked token endpoint.
    sent_scope = _get_scope_from_request(token_endpoint.last_request.text)
    assert sent_scope == expected_scope


@pytest.mark.parametrize(
    "scopes_input,expected_scope",
    [
        (None, "all-apis"),
        (["unity-catalog:read", "clusters"], "clusters unity-catalog:read"),
        (["jobs:read"], "jobs:read"),
    ],
    ids=["default_scope", "multiple_scopes", "single_scope"],
)
def test_oidc_scopes_sent_to_token_endpoint(requests_mock, tmp_path, scopes_input, expected_scope):
    """The file-based OIDC flow posts the expected ``scope`` value to the token endpoint."""
    # The ID token is read from a file, so write a placeholder token to disk.
    id_token_path = tmp_path / "oidc_token"
    id_token_path.write_text("mock-id-token")

    discovery_url = "https://test.databricks.com/oidc/.well-known/oauth-authorization-server"
    token_url = "https://test.databricks.com/oidc/v1/token"

    # Mocked OIDC discovery document pointing at the mocked token endpoint.
    discovery_document = {
        "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
        "token_endpoint": token_url,
    }
    requests_mock.get(discovery_url, json=discovery_document)

    token_response = {"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600}
    token_endpoint = requests_mock.post(token_url, json=token_response)

    cfg = Config(
        host="https://test.databricks.com",
        oidc_token_filepath=str(id_token_path),
        auth_type="file-oidc",
        scopes=scopes_input,
    )
    cfg.authenticate()

    # Inspect the last request recorded by the mocked token endpoint.
    sent_scope = _get_scope_from_request(token_endpoint.last_request.text)
    assert sent_scope == expected_scope
Loading
Loading