Skip to content

Commit 3151fe5

Browse files
authored
Merge branch 'main' into yuanjie/fix-delta-sharing-download
2 parents afc095c + 2df6609 commit 3151fe5

File tree

7 files changed

+341
-7
lines changed

7 files changed

+341
-7
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Release v0.74.0
44

55
### New Features and Improvements
6+
* Add new auth type (`runtime-oauth`) for notebooks: Introduce a new authentication mechanism that allows notebooks to authenticate using OAuth tokens
67

78
### Security
89

databricks/sdk/__init__.py

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ class Config:
113113
disable_experimental_files_api_client: bool = ConfigAttribute(
114114
env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT"
115115
)
116+
# TODO: Expose these via environment variables too.
117+
scopes: str = ConfigAttribute()
118+
authorization_details: str = ConfigAttribute()
116119

117120
files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB
118121

databricks/sdk/credentials_provider.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,42 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
176176
return None
177177

178178

179+
@oauth_credentials_strategy("runtime-oauth", ["scopes"])
180+
def runtime_oauth(cfg: "Config") -> Optional[CredentialsProvider]:
181+
if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
182+
return None
183+
184+
def get_notebook_pat_token() -> Optional[str]:
185+
native_auth = runtime_native_auth(cfg)
186+
if native_auth is None:
187+
return None
188+
notebook_pat_token = None
189+
notebook_pat_authorization = native_auth().get("Authorization", "").strip()
190+
if notebook_pat_authorization.lower().startswith("bearer "):
191+
notebook_pat_token = notebook_pat_authorization[len("bearer ") :].strip()
192+
return notebook_pat_token
193+
194+
notebook_pat_token = get_notebook_pat_token()
195+
if notebook_pat_token is None:
196+
return None
197+
198+
token_source = oauth.PATOAuthTokenExchange(
199+
get_original_token=get_notebook_pat_token,
200+
host=cfg.host,
201+
scopes=cfg.scopes,
202+
authorization_details=cfg.authorization_details,
203+
)
204+
205+
def inner() -> Dict[str, str]:
206+
token = token_source.token()
207+
return {"Authorization": f"{token.token_type} {token.access_token}"}
208+
209+
def token() -> oauth.Token:
210+
return token_source.token()
211+
212+
return OAuthCredentialsProvider(inner, token)
213+
214+
179215
@oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
180216
def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
181217
"""Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
@@ -189,9 +225,10 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
189225
client_id=cfg.client_id,
190226
client_secret=cfg.client_secret,
191227
token_url=oidc.token_endpoint,
192-
scopes=["all-apis"],
228+
scopes=cfg.scopes or "all-apis",
193229
use_header=True,
194230
disable_async=cfg.disable_async_token_refresh,
231+
authorization_details=cfg.authorization_details,
195232
)
196233

197234
def inner() -> Dict[str, str]:
@@ -292,6 +329,8 @@ def token_source_for(resource: str) -> oauth.TokenSource:
292329
endpoint_params={"resource": resource},
293330
use_params=True,
294331
disable_async=cfg.disable_async_token_refresh,
332+
scopes=cfg.scopes,
333+
authorization_details=cfg.authorization_details,
295334
)
296335

297336
_ensure_host_present(cfg, token_source_for)
@@ -411,9 +450,10 @@ def token_source_for(audience: str) -> oauth.TokenSource:
411450
"subject_token": id_token,
412451
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
413452
},
414-
scopes=["all-apis"],
453+
scopes=cfg.scopes or "all-apis",
415454
use_params=True,
416455
disable_async=cfg.disable_async_token_refresh,
456+
authorization_details=cfg.authorization_details,
417457
)
418458

419459
def refreshed_headers() -> Dict[str, str]:
@@ -493,6 +533,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
493533
},
494534
use_params=True,
495535
disable_async=cfg.disable_async_token_refresh,
536+
scopes=cfg.scopes,
537+
authorization_details=cfg.authorization_details,
496538
)
497539

498540
def refreshed_headers() -> Dict[str, str]:
@@ -1070,6 +1112,7 @@ def __init__(self) -> None:
10701112
azure_devops_oidc,
10711113
external_browser,
10721114
databricks_cli,
1115+
runtime_oauth,
10731116
runtime_native_auth,
10741117
google_credentials,
10751118
google_id,

databricks/sdk/oauth.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from datetime import datetime, timedelta
1515
from enum import Enum
1616
from http.server import BaseHTTPRequestHandler, HTTPServer
17-
from typing import Any, Dict, List, Optional
17+
from typing import Any, Callable, Dict, List, Optional
1818

1919
import requests
2020
import requests.auth
@@ -32,6 +32,30 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35+
@dataclass
36+
class AuthorizationDetail:
37+
type: str
38+
object_type: str
39+
object_path: str
40+
actions: List[str]
41+
42+
def as_dict(self) -> dict:
43+
return {
44+
"type": self.type,
45+
"object_type": self.object_type,
46+
"object_path": self.object_path,
47+
"actions": self.actions,
48+
}
49+
50+
def from_dict(self, d: dict) -> "AuthorizationDetail":
51+
return AuthorizationDetail(
52+
type=d.get("type"),
53+
object_type=d.get("object_type"),
54+
object_path=d.get("object_path"),
55+
actions=d.get("actions"),
56+
)
57+
58+
3559
class IgnoreNetrcAuth(requests.auth.AuthBase):
3660
"""This auth method is a no-op.
3761
@@ -706,18 +730,21 @@ class ClientCredentials(Refreshable):
706730
client_secret: str
707731
token_url: str
708732
endpoint_params: dict = None
709-
scopes: List[str] = None
733+
scopes: str = None
710734
use_params: bool = False
711735
use_header: bool = False
712736
disable_async: bool = True
737+
authorization_details: str = None
713738

714739
def __post_init__(self):
715740
super().__init__(disable_async=self.disable_async)
716741

717742
def refresh(self) -> Token:
718743
params = {"grant_type": "client_credentials"}
719744
if self.scopes:
720-
params["scope"] = " ".join(self.scopes)
745+
params["scope"] = self.scopes
746+
if self.authorization_details:
747+
params["authorization_details"] = self.authorization_details
721748
if self.endpoint_params:
722749
for k, v in self.endpoint_params.items():
723750
params[k] = v
@@ -731,6 +758,67 @@ def refresh(self) -> Token:
731758
)
732759

733760

761+
@dataclass
762+
class PATOAuthTokenExchange(Refreshable):
763+
"""Performs OAuth token exchange using a Personal Access Token (PAT) as the subject token.
764+
765+
This class implements the OAuth 2.0 Token Exchange flow (RFC 8693) to exchange a Databricks
766+
Internal PAT Token for an access token with specific scopes and authorization details.
767+
768+
Args:
769+
get_original_token: A callable that returns the PAT to be exchanged. This is a callable
770+
rather than a string value to ensure that a fresh Internal PAT Token is retrieved
771+
at the time of refresh.
772+
host: The Databricks workspace URL (e.g., "https://my-workspace.cloud.databricks.com").
773+
scopes: Space-delimited string of OAuth scopes to request (e.g., "all-apis offline_access").
774+
authorization_details: Optional JSON string containing authorization details as defined in
775+
AuthorizationDetail class above.
776+
disable_async: Whether to disable asynchronous token refresh. Defaults to True.
777+
"""
778+
779+
get_original_token: Callable[[], Optional[str]]
780+
host: str
781+
scopes: str
782+
authorization_details: str = None
783+
disable_async: bool = True
784+
785+
def __post_init__(self):
786+
super().__init__(disable_async=self.disable_async)
787+
788+
def refresh(self) -> Token:
789+
token_exchange_url = f"{self.host}/oidc/v1/token"
790+
params = {
791+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
792+
"subject_token": self.get_original_token(),
793+
"subject_token_type": "urn:databricks:params:oauth:token-type:personal-access-token",
794+
"requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
795+
"scope": self.scopes,
796+
}
797+
if self.authorization_details:
798+
params["authorization_details"] = self.authorization_details
799+
800+
resp = requests.post(token_exchange_url, params)
801+
if not resp.ok:
802+
if resp.headers["Content-Type"].startswith("application/json"):
803+
err = resp.json()
804+
code = err.get("errorCode", err.get("error", "unknown"))
805+
summary = err.get("errorSummary", err.get("error_description", "unknown"))
806+
summary = summary.replace("\r\n", " ")
807+
raise ValueError(f"{code}: {summary}")
808+
raise ValueError(resp.content)
809+
try:
810+
j = resp.json()
811+
expires_in = int(j["expires_in"])
812+
expiry = datetime.now() + timedelta(seconds=expires_in)
813+
return Token(
814+
access_token=j["access_token"],
815+
expiry=expiry,
816+
token_type=j["token_type"],
817+
)
818+
except Exception as e:
819+
raise ValueError(f"Failed to exchange PAT for OAuth token: {e}")
820+
821+
734822
class TokenCache:
735823
BASE_PATH = "~/.config/databricks-sdk-py/oauth"
736824

databricks/sdk/oidc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
202202
"subject_token": id_token.jwt,
203203
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
204204
},
205-
scopes=["all-apis"],
205+
scopes="all-apis",
206206
use_params=True,
207207
disable_async=self._disable_async,
208208
)

0 commit comments

Comments
 (0)