Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=cfg.client_id,
client_secret=cfg.client_secret,
token_url=oidc.token_endpoint,
scopes=cfg.scopes or "all-apis",
scopes=cfg.get_scopes_as_string(),
use_header=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
Expand Down Expand Up @@ -256,6 +256,11 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
if not client_id:
client_id = "databricks-cli"

scopes = cfg.get_scopes()
if not cfg.disable_oauth_refresh_token:
if "offline_access" not in scopes:
scopes = scopes + ["offline_access"]

# Load cached credentials from disk if they exist. Note that these are
# local to the Python SDK and not reused by other SDKs.
oidc_endpoints = cfg.oidc_endpoints
Expand All @@ -266,6 +271,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=client_id,
client_secret=client_secret,
redirect_url=redirect_url,
scopes=scopes,
)
credentials = token_cache.load()
if credentials:
Expand All @@ -284,6 +290,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=client_id,
redirect_url=redirect_url,
client_secret=client_secret,
scopes=scopes,
)
consent = oauth_client.initiate_consent()
if not consent:
Expand Down Expand Up @@ -387,6 +394,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
account_id=cfg.account_id,
id_token_source=id_token_source,
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -450,7 +458,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
"subject_token": id_token,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes=cfg.scopes or "all-apis",
scopes=cfg.get_scopes_as_string(),
use_params=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
Expand Down
8 changes: 4 additions & 4 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,10 @@ def __init__(
client_secret: str = None,
):
if not scopes:
# all-apis ensures that the returned OAuth token can be used with all APIs, aside
# from direct-to-dataplane APIs.
# offline_access ensures that the response from the Authorization server includes
# a refresh token.
# Default for direct OAuthClient users (e.g., via from_host()).
# When used via credentials_provider.external_browser(), scopes are always
# passed explicitly from Config.get_scopes(), with offline_access handling
# controlled by the disable_oauth_refresh_token flag.
scopes = ["all-apis", "offline_access"]

self.redirect_url = redirect_url
Expand Down
4 changes: 3 additions & 1 deletion databricks/sdk/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
account_id: Optional[str] = None,
audience: Optional[str] = None,
disable_async: bool = False,
scopes: Optional[str] = None,
):
self._host = host
self._id_token_source = id_token_source
Expand All @@ -164,6 +165,7 @@ def __init__(
self._account_id = account_id
self._audience = audience
self._disable_async = disable_async
self._scopes = scopes

def token(self) -> oauth.Token:
"""Get a token by exchanging the ID token.
Expand Down Expand Up @@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
"subject_token": id_token.jwt,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes="all-apis",
scopes=self._scopes,
use_params=True,
disable_async=self._disable_async,
)
Expand Down
91 changes: 91 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import random
import string
from datetime import datetime
from typing import Optional
from urllib.parse import parse_qs

import pytest

Expand Down Expand Up @@ -530,3 +532,92 @@ def test_scopes_parsing(mocker, scopes_input, expected_scopes):
mocker.patch("databricks.sdk.config.Config.init_auth")
config = Config(host="https://test.databricks.com", scopes=scopes_input)
assert config.get_scopes() == expected_scopes


def test_config_file_scopes_multiple_sorted(monkeypatch, mocker):
"""Test multiple scopes from config file are sorted."""
mocker.patch("databricks.sdk.config.Config.init_auth")
set_home(monkeypatch, "/testdata")
config = Config(profile="scope-multiple")
# Should be sorted alphabetically
expected = ["clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving:read", "pipelines"]
assert config.get_scopes() == expected


def _get_scope_from_request(request_text: str) -> Optional[str]:
"""Extract the scope value from a URL-encoded request body."""
params = parse_qs(request_text)
scope_list = params.get("scope")
return scope_list[0] if scope_list else None


@pytest.mark.parametrize(
"scopes_input,expected_scope",
[
(None, "all-apis"),
(["unity-catalog:read"], "unity-catalog:read"),
(["jobs:read", "clusters", "mlflow:read"], "clusters jobs:read mlflow:read"),
],
ids=["default_scope", "single_custom_scope", "multiple_scopes_sorted"],
)
def test_m2m_scopes_sent_to_token_endpoint(requests_mock, scopes_input, expected_scope):
"""Test M2M authentication sends correct scopes to token endpoint."""
requests_mock.get(
"https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
json={
"authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
"token_endpoint": "https://test.databricks.com/oidc/v1/token",
},
)
token_mock = requests_mock.post(
"https://test.databricks.com/oidc/v1/token",
json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
)

config = Config(
host="https://test.databricks.com",
client_id="test-client-id",
client_secret="test-client-secret",
auth_type="oauth-m2m",
scopes=scopes_input,
)
config.authenticate()

assert _get_scope_from_request(token_mock.last_request.text) == expected_scope


@pytest.mark.parametrize(
"scopes_input,expected_scope",
[
(None, "all-apis"),
(["unity-catalog:read", "clusters"], "clusters unity-catalog:read"),
(["jobs:read"], "jobs:read"),
],
ids=["default_scope", "multiple_scopes", "single_scope"],
)
def test_oidc_scopes_sent_to_token_endpoint(requests_mock, tmp_path, scopes_input, expected_scope):
"""Test OIDC token exchange sends correct scopes to token endpoint."""
oidc_token_file = tmp_path / "oidc_token"
oidc_token_file.write_text("mock-id-token")

requests_mock.get(
"https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
json={
"authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
"token_endpoint": "https://test.databricks.com/oidc/v1/token",
},
)
token_mock = requests_mock.post(
"https://test.databricks.com/oidc/v1/token",
json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
)

config = Config(
host="https://test.databricks.com",
oidc_token_filepath=str(oidc_token_file),
auth_type="file-oidc",
scopes=scopes_input,
)
config.authenticate()

assert _get_scope_from_request(token_mock.last_request.text) == expected_scope
51 changes: 51 additions & 0 deletions tests/test_credentials_provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from datetime import datetime, timedelta
from unittest.mock import Mock

import pytest

from databricks.sdk import credentials_provider, oauth, oidc
from databricks.sdk.config import Config


# Tests for external_browser function
Expand Down Expand Up @@ -174,6 +177,54 @@ def test_external_browser_consent_fails(mocker):
assert got_credentials_provider is None


def _setup_external_browser_mocks(mocker, cfg):
"""Set up mocks for external_browser scope tests. Returns (TokenCache mock, OAuthClient mock)."""
mock_oidc_endpoints = Mock()
mock_oidc_endpoints.token_endpoint = "https://test.databricks.com/oidc/v1/token"
mocker.patch.object(type(cfg), "oidc_endpoints", new_callable=lambda: property(lambda self: mock_oidc_endpoints))

mock_token_cache_class = mocker.patch("databricks.sdk.credentials_provider.oauth.TokenCache")
mock_token_cache = Mock()
mock_token_cache.load.return_value = None
mock_token_cache_class.return_value = mock_token_cache

mock_oauth_client_class = mocker.patch("databricks.sdk.credentials_provider.oauth.OAuthClient")
mock_oauth_client = Mock()
mock_consent = Mock()
mock_consent.launch_external_browser.return_value = Mock()
mock_oauth_client.initiate_consent.return_value = mock_consent
mock_oauth_client_class.return_value = mock_oauth_client

return mock_token_cache_class, mock_oauth_client_class


@pytest.mark.parametrize(
"scopes,disable_refresh,expected_scopes",
[
(None, False, ["all-apis", "offline_access"]),
("sql, clusters, jobs", False, ["clusters", "jobs", "sql", "offline_access"]),
(None, True, ["all-apis"]),
("sql, clusters, jobs, offline_access", False, ["clusters", "jobs", "offline_access", "sql"]),
],
ids=["default_scopes", "multiple_scopes_sorted", "disable_offline_access", "offline_access_not_duplicated"],
)
def test_external_browser_scopes(mocker, scopes, disable_refresh, expected_scopes):
"""Tests that external_browser passes correct scopes to TokenCache and OAuthClient."""
mocker.patch("databricks.sdk.config.Config.init_auth")
cfg = Config(
host="https://test.databricks.com",
auth_type="external-browser",
scopes=scopes,
disable_oauth_refresh_token=disable_refresh if disable_refresh else None,
)
mock_token_cache_class, mock_oauth_client_class = _setup_external_browser_mocks(mocker, cfg)

credentials_provider.external_browser(cfg)

assert mock_token_cache_class.call_args.kwargs["scopes"] == expected_scopes
assert mock_oauth_client_class.call_args.kwargs["scopes"] == expected_scopes


def test_oidc_credentials_provider_invalid_id_token_source():
# Use a mock config object to avoid initializing the auth initialization.
mock_cfg = Mock()
Expand Down
Loading