Skip to content

Commit 9b6ca63

Browse files
store scopes as list in config
1 parent 866bde4 commit 9b6ca63

File tree

4 files changed

+33
-42
lines changed

4 files changed

+33
-42
lines changed

databricks/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
import os
66
import pathlib
7-
import re
87
import sys
98
import urllib.parse
109
from enum import Enum
@@ -47,11 +46,13 @@ class ConfigAttribute:
4746
# name and transform are discovered from Config.__new__
4847
name: str = None
4948
transform: type = str
49+
_custom_transform = None
5050

51-
def __init__(self, env: str = None, auth: str = None, sensitive: bool = False):
51+
def __init__(self, env: str = None, auth: str = None, sensitive: bool = False, transform=None):
5252
self.env = env
5353
self.auth = auth
5454
self.sensitive = sensitive
55+
self._custom_transform = transform
5556

5657
def __get__(self, cfg: "Config", owner):
5758
if not cfg:
@@ -65,6 +66,19 @@ def __repr__(self) -> str:
6566
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"
6667

6768

69+
def _parse_scopes(value):
70+
"""Parse scopes into a deduplicated, sorted list."""
71+
if value is None:
72+
return None
73+
if isinstance(value, list):
74+
result = sorted(set(s for s in value if s))
75+
return result if result else None
76+
if isinstance(value, str):
77+
parsed = sorted(set(s.strip() for s in value.split(",") if s.strip()))
78+
return parsed if parsed else None
79+
return None
80+
81+
6882
def with_product(product: str, product_version: str):
6983
"""[INTERNAL API] Change the product name and version used in the User-Agent header."""
7084
useragent.with_product(product, product_version)
@@ -134,12 +148,12 @@ class Config:
134148
disable_experimental_files_api_client: bool = ConfigAttribute(
135149
env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT"
136150
)
137-
# TODO: Expose these via environment variables too.
138-
scopes: str = ConfigAttribute()
151+
152+
scopes: List[str] = ConfigAttribute(transform=_parse_scopes)
139153
authorization_details: str = ConfigAttribute()
140154

141-
# Controls whether the offline_access scope is requested during U2M OAuth authentication.
142-
# offline_access is requested by default, causing a refresh token to be included in the OAuth token.
155+
# disable_oauth_refresh_token controls whether a refresh token should be requested
156+
# during the U2M authentication flow (defaults to false, i.e. a refresh token is requested).
143157
disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN")
144158

145159
files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB
@@ -270,7 +284,6 @@ def __init__(
270284
self._known_file_config_loader()
271285
self._fix_host_if_needed()
272286
self._validate()
273-
self._sort_scopes()
274287
self.init_auth()
275288
self._init_product(product, product_version)
276289
except ValueError as e:
@@ -559,7 +572,7 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
559572
if type(v) != ConfigAttribute:
560573
continue
561574
v.name = name
562-
v.transform = anno.get(name, str)
575+
v.transform = v._custom_transform if v._custom_transform else anno.get(name, str)
563576
attrs.append(v)
564577
cls._attributes = attrs
565578
return cls._attributes
@@ -672,16 +685,6 @@ def _validate(self):
672685
names = " and ".join(sorted(auths_used))
673686
raise ValueError(f"validate: more than one authorization method configured: {names}")
674687

675-
def _sort_scopes(self):
676-
"""Sort scopes in-place for better de-duplication in the refresh token cache.
677-
Delimiter is set to a single whitespace after sorting."""
678-
if self.scopes and isinstance(self.scopes, str):
679-
# Split on whitespaces and commas, sort, and rejoin
680-
parsed = [s for s in re.split(r"[\s,]+", self.scopes) if s]
681-
if parsed:
682-
parsed.sort()
683-
self.scopes = " ".join(parsed)
684-
685688
def init_auth(self):
686689
try:
687690
self._header_factory = self._credentials_strategy(self)
@@ -706,26 +709,14 @@ def get_scopes(self) -> List[str]:
706709
707710
Returns ["all-apis"] if no scopes configured.
708711
This is the single source of truth for scope defaulting across all OAuth methods.
709-
710-
Parses string scopes by splitting on whitespaces and commas.
711-
712-
Returns:
713-
List of scope strings.
714712
"""
715-
if self.scopes and isinstance(self.scopes, str):
716-
parsed = [s for s in re.split(r"[\s,]+", self.scopes) if s]
717-
if not parsed: # Empty string case
718-
return ["all-apis"]
719-
return parsed
720-
return ["all-apis"]
713+
return self.scopes if self.scopes else ["all-apis"]
721714

722715
def get_scopes_as_string(self) -> str:
723716
"""Get OAuth scopes as a space-separated string.
724717
725718
Returns "all-apis" if no scopes configured.
726719
"""
727-
if self.scopes and isinstance(self.scopes, str):
728-
return self.scopes
729720
return " ".join(self.get_scopes())
730721

731722
def __repr__(self):

databricks/sdk/credentials_provider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def get_notebook_pat_token() -> Optional[str]:
198198
token_source = oauth.PATOAuthTokenExchange(
199199
get_original_token=get_notebook_pat_token,
200200
host=cfg.host,
201-
scopes=cfg.scopes,
201+
scopes=cfg.get_scopes_as_string(),
202202
authorization_details=cfg.authorization_details,
203203
)
204204

@@ -329,7 +329,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
329329
endpoint_params={"resource": resource},
330330
use_params=True,
331331
disable_async=cfg.disable_async_token_refresh,
332-
scopes=cfg.scopes,
332+
scopes=cfg.get_scopes_as_string(),
333333
authorization_details=cfg.authorization_details,
334334
)
335335

@@ -533,7 +533,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
533533
},
534534
use_params=True,
535535
disable_async=cfg.disable_async_token_refresh,
536-
scopes=cfg.scopes,
536+
scopes=cfg.get_scopes_as_string(),
537537
authorization_details=cfg.authorization_details,
538538
)
539539

tests/test_notebook_oauth.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def credentials_provider() -> Dict[str, str]:
7878
@pytest.mark.parametrize(
7979
"scopes,auth_details",
8080
[
81-
("sql offline_access", None),
82-
("sql offline_access", '{"type": "databricks_resource"}'),
81+
("sql, offline_access", None),
82+
("sql, offline_access", '{"type": "databricks_resource"}'),
8383
("sql", None),
84-
("sql offline_access all-apis", None),
84+
("sql, offline_access, all-apis", None),
8585
],
8686
)
8787
def test_runtime_oauth_success_scenarios(
@@ -117,7 +117,7 @@ def test_runtime_oauth_missing_scopes(mock_runtime_env, mock_runtime_native_auth
117117

118118
def test_runtime_oauth_priority_over_native_auth(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange):
119119
"""Test that runtime-oauth is prioritized over runtime-native-auth."""
120-
cfg = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access")
120+
cfg = Config(host="https://test.cloud.databricks.com", scopes="sql, offline_access")
121121

122122
default_creds = DefaultCredentials()
123123
creds_provider = default_creds(cfg)
@@ -141,7 +141,7 @@ def test_fallback_to_native_auth_without_scopes(mock_runtime_env, mock_runtime_n
141141

142142
def test_explicit_runtime_oauth_auth_type(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange):
143143
"""Test that runtime-oauth is used when explicitly specified as auth_type."""
144-
cfg = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access", auth_type="runtime-oauth")
144+
cfg = Config(host="https://test.cloud.databricks.com", scopes="sql, offline_access", auth_type="runtime-oauth")
145145

146146
default_creds = DefaultCredentials()
147147
creds_provider = default_creds(cfg)
@@ -164,7 +164,7 @@ def test_config_authenticate_integration(
164164
"""Test Config.authenticate() integration with runtime-oauth and fallback."""
165165
cfg_kwargs = {"host": "https://test.cloud.databricks.com"}
166166
if has_scopes:
167-
cfg_kwargs["scopes"] = "sql offline_access"
167+
cfg_kwargs["scopes"] = "sql, offline_access"
168168

169169
cfg = Config(**cfg_kwargs)
170170
headers = cfg.authenticate()
@@ -174,7 +174,7 @@ def test_config_authenticate_integration(
174174

175175
@pytest.mark.parametrize(
176176
"scopes_input,expected_scopes",
177-
[(["sql", "offline_access"], "offline_access sql")],
177+
[(["sql", "offline_access"], ["offline_access", "sql"])],
178178
)
179179
def test_workspace_client_integration(
180180
mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes

0 commit comments

Comments
 (0)