Skip to content

Commit 1e31fec

Browse files
Custom scopes support in M2M, WIF and U2M (#1194)
## Summary - Update M2M, WIF and U2M auth to use get_scopes instead of hard-coded values. ## How is this tested? - Integration tests with mocking to ensure correct scopes propogated to token requests. --- NO_CHANGELOG=true
1 parent 0ec836e commit 1e31fec

File tree

5 files changed

+159
-7
lines changed

5 files changed

+159
-7
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
225225
client_id=cfg.client_id,
226226
client_secret=cfg.client_secret,
227227
token_url=oidc.token_endpoint,
228-
scopes=cfg.scopes or "all-apis",
228+
scopes=cfg.get_scopes_as_string(),
229229
use_header=True,
230230
disable_async=cfg.disable_async_token_refresh,
231231
authorization_details=cfg.authorization_details,
@@ -256,6 +256,11 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
256256
if not client_id:
257257
client_id = "databricks-cli"
258258

259+
scopes = cfg.get_scopes()
260+
if not cfg.disable_oauth_refresh_token:
261+
if "offline_access" not in scopes:
262+
scopes = scopes + ["offline_access"]
263+
259264
# Load cached credentials from disk if they exist. Note that these are
260265
# local to the Python SDK and not reused by other SDKs.
261266
oidc_endpoints = cfg.oidc_endpoints
@@ -266,6 +271,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
266271
client_id=client_id,
267272
client_secret=client_secret,
268273
redirect_url=redirect_url,
274+
scopes=scopes,
269275
)
270276
credentials = token_cache.load()
271277
if credentials:
@@ -284,6 +290,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
284290
client_id=client_id,
285291
redirect_url=redirect_url,
286292
client_secret=client_secret,
293+
scopes=scopes,
287294
)
288295
consent = oauth_client.initiate_consent()
289296
if not consent:
@@ -387,6 +394,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
387394
account_id=cfg.account_id,
388395
id_token_source=id_token_source,
389396
disable_async=cfg.disable_async_token_refresh,
397+
scopes=cfg.get_scopes_as_string(),
390398
)
391399

392400
def refreshed_headers() -> Dict[str, str]:
@@ -450,7 +458,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
450458
"subject_token": id_token,
451459
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
452460
},
453-
scopes=cfg.scopes or "all-apis",
461+
scopes=cfg.get_scopes_as_string(),
454462
use_params=True,
455463
disable_async=cfg.disable_async_token_refresh,
456464
authorization_details=cfg.authorization_details,

databricks/sdk/oauth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -663,10 +663,10 @@ def __init__(
663663
client_secret: str = None,
664664
):
665665
if not scopes:
666-
# all-apis ensures that the returned OAuth token can be used with all APIs, aside
667-
# from direct-to-dataplane APIs.
668-
# offline_access ensures that the response from the Authorization server includes
669-
# a refresh token.
666+
# Default for direct OAuthClient users (e.g., via from_host()).
667+
# When used via credentials_provider.external_browser(), scopes are always
668+
# passed explicitly from Config.get_scopes(), with offline_access handling
669+
# controlled by the disable_oauth_refresh_token flag.
670670
scopes = ["all-apis", "offline_access"]
671671

672672
self.redirect_url = redirect_url

databricks/sdk/oidc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(
156156
account_id: Optional[str] = None,
157157
audience: Optional[str] = None,
158158
disable_async: bool = False,
159+
scopes: Optional[str] = None,
159160
):
160161
self._host = host
161162
self._id_token_source = id_token_source
@@ -164,6 +165,7 @@ def __init__(
164165
self._account_id = account_id
165166
self._audience = audience
166167
self._disable_async = disable_async
168+
self._scopes = scopes
167169

168170
def token(self) -> oauth.Token:
169171
"""Get a token by exchanging the ID token.
@@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
202204
"subject_token": id_token.jwt,
203205
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
204206
},
205-
scopes="all-apis",
207+
scopes=self._scopes,
206208
use_params=True,
207209
disable_async=self._disable_async,
208210
)

tests/test_config.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import random
55
import string
66
from datetime import datetime
7+
from typing import Optional
8+
from urllib.parse import parse_qs
79

810
import pytest
911

@@ -530,3 +532,92 @@ def test_scopes_parsing(mocker, scopes_input, expected_scopes):
530532
mocker.patch("databricks.sdk.config.Config.init_auth")
531533
config = Config(host="https://test.databricks.com", scopes=scopes_input)
532534
assert config.get_scopes() == expected_scopes
535+
536+
537+
def test_config_file_scopes_multiple_sorted(monkeypatch, mocker):
538+
"""Test multiple scopes from config file are sorted."""
539+
mocker.patch("databricks.sdk.config.Config.init_auth")
540+
set_home(monkeypatch, "/testdata")
541+
config = Config(profile="scope-multiple")
542+
# Should be sorted alphabetically
543+
expected = ["clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving:read", "pipelines"]
544+
assert config.get_scopes() == expected
545+
546+
547+
def _get_scope_from_request(request_text: str) -> Optional[str]:
548+
"""Extract the scope value from a URL-encoded request body."""
549+
params = parse_qs(request_text)
550+
scope_list = params.get("scope")
551+
return scope_list[0] if scope_list else None
552+
553+
554+
@pytest.mark.parametrize(
555+
"scopes_input,expected_scope",
556+
[
557+
(None, "all-apis"),
558+
(["unity-catalog:read"], "unity-catalog:read"),
559+
(["jobs:read", "clusters", "mlflow:read"], "clusters jobs:read mlflow:read"),
560+
],
561+
ids=["default_scope", "single_custom_scope", "multiple_scopes_sorted"],
562+
)
563+
def test_m2m_scopes_sent_to_token_endpoint(requests_mock, scopes_input, expected_scope):
564+
"""Test M2M authentication sends correct scopes to token endpoint."""
565+
requests_mock.get(
566+
"https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
567+
json={
568+
"authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
569+
"token_endpoint": "https://test.databricks.com/oidc/v1/token",
570+
},
571+
)
572+
token_mock = requests_mock.post(
573+
"https://test.databricks.com/oidc/v1/token",
574+
json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
575+
)
576+
577+
config = Config(
578+
host="https://test.databricks.com",
579+
client_id="test-client-id",
580+
client_secret="test-client-secret",
581+
auth_type="oauth-m2m",
582+
scopes=scopes_input,
583+
)
584+
config.authenticate()
585+
586+
assert _get_scope_from_request(token_mock.last_request.text) == expected_scope
587+
588+
589+
@pytest.mark.parametrize(
590+
"scopes_input,expected_scope",
591+
[
592+
(None, "all-apis"),
593+
(["unity-catalog:read", "clusters"], "clusters unity-catalog:read"),
594+
(["jobs:read"], "jobs:read"),
595+
],
596+
ids=["default_scope", "multiple_scopes", "single_scope"],
597+
)
598+
def test_oidc_scopes_sent_to_token_endpoint(requests_mock, tmp_path, scopes_input, expected_scope):
599+
"""Test OIDC token exchange sends correct scopes to token endpoint."""
600+
oidc_token_file = tmp_path / "oidc_token"
601+
oidc_token_file.write_text("mock-id-token")
602+
603+
requests_mock.get(
604+
"https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
605+
json={
606+
"authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
607+
"token_endpoint": "https://test.databricks.com/oidc/v1/token",
608+
},
609+
)
610+
token_mock = requests_mock.post(
611+
"https://test.databricks.com/oidc/v1/token",
612+
json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
613+
)
614+
615+
config = Config(
616+
host="https://test.databricks.com",
617+
oidc_token_filepath=str(oidc_token_file),
618+
auth_type="file-oidc",
619+
scopes=scopes_input,
620+
)
621+
config.authenticate()
622+
623+
assert _get_scope_from_request(token_mock.last_request.text) == expected_scope

tests/test_credentials_provider.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from datetime import datetime, timedelta
22
from unittest.mock import Mock
33

4+
import pytest
5+
46
from databricks.sdk import credentials_provider, oauth, oidc
7+
from databricks.sdk.config import Config
58

69

710
# Tests for external_browser function
@@ -174,6 +177,54 @@ def test_external_browser_consent_fails(mocker):
174177
assert got_credentials_provider is None
175178

176179

180+
def _setup_external_browser_mocks(mocker, cfg):
181+
"""Set up mocks for external_browser scope tests. Returns (TokenCache mock, OAuthClient mock)."""
182+
mock_oidc_endpoints = Mock()
183+
mock_oidc_endpoints.token_endpoint = "https://test.databricks.com/oidc/v1/token"
184+
mocker.patch.object(type(cfg), "oidc_endpoints", new_callable=lambda: property(lambda self: mock_oidc_endpoints))
185+
186+
mock_token_cache_class = mocker.patch("databricks.sdk.credentials_provider.oauth.TokenCache")
187+
mock_token_cache = Mock()
188+
mock_token_cache.load.return_value = None
189+
mock_token_cache_class.return_value = mock_token_cache
190+
191+
mock_oauth_client_class = mocker.patch("databricks.sdk.credentials_provider.oauth.OAuthClient")
192+
mock_oauth_client = Mock()
193+
mock_consent = Mock()
194+
mock_consent.launch_external_browser.return_value = Mock()
195+
mock_oauth_client.initiate_consent.return_value = mock_consent
196+
mock_oauth_client_class.return_value = mock_oauth_client
197+
198+
return mock_token_cache_class, mock_oauth_client_class
199+
200+
201+
@pytest.mark.parametrize(
202+
"scopes,disable_refresh,expected_scopes",
203+
[
204+
(None, False, ["all-apis", "offline_access"]),
205+
("sql, clusters, jobs", False, ["clusters", "jobs", "sql", "offline_access"]),
206+
(None, True, ["all-apis"]),
207+
("sql, clusters, jobs, offline_access", False, ["clusters", "jobs", "offline_access", "sql"]),
208+
],
209+
ids=["default_scopes", "multiple_scopes_sorted", "disable_offline_access", "offline_access_not_duplicated"],
210+
)
211+
def test_external_browser_scopes(mocker, scopes, disable_refresh, expected_scopes):
212+
"""Tests that external_browser passes correct scopes to TokenCache and OAuthClient."""
213+
mocker.patch("databricks.sdk.config.Config.init_auth")
214+
cfg = Config(
215+
host="https://test.databricks.com",
216+
auth_type="external-browser",
217+
scopes=scopes,
218+
disable_oauth_refresh_token=disable_refresh if disable_refresh else None,
219+
)
220+
mock_token_cache_class, mock_oauth_client_class = _setup_external_browser_mocks(mocker, cfg)
221+
222+
credentials_provider.external_browser(cfg)
223+
224+
assert mock_token_cache_class.call_args.kwargs["scopes"] == expected_scopes
225+
assert mock_oauth_client_class.call_args.kwargs["scopes"] == expected_scopes
226+
227+
177228
def test_oidc_credentials_provider_invalid_id_token_source():
178229
# Use a mock config object to avoid initializing the auth initialization.
179230
mock_cfg = Mock()

0 commit comments

Comments
 (0)