Skip to content

Commit e628a00

Browse files
committed
Add new credential provider
1 parent 6210fc8 commit e628a00

File tree

6 files changed

+292
-7
lines changed

6 files changed

+292
-7
lines changed

databricks/sdk/__init__.py

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class Config:
113113
disable_experimental_files_api_client: bool = ConfigAttribute(
114114
env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT"
115115
)
116+
scopes: str = ConfigAttribute()
117+
authorization_details: str = ConfigAttribute()
116118

117119
files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB
118120

databricks/sdk/credentials_provider.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,42 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
176176
return None
177177

178178

179+
@oauth_credentials_strategy("runtime-oauth", ["scopes"])
180+
def runtime_oauth(cfg: "Config") -> Optional[CredentialsProvider]:
181+
if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
182+
return None
183+
184+
def get_notebook_pat_token() -> Optional[str]:
185+
native_auth = runtime_native_auth(cfg)
186+
if native_auth is None:
187+
return None
188+
notebook_pat_token = None
189+
notebook_pat_authorization = native_auth().get("Authorization", "").strip()
190+
if notebook_pat_authorization.lower().startswith("bearer "):
191+
notebook_pat_token = notebook_pat_authorization[len("bearer ") :].strip()
192+
return notebook_pat_token
193+
194+
notebook_pat_token = get_notebook_pat_token()
195+
if notebook_pat_token is None:
196+
return None
197+
198+
token_source = oauth.PATOAuthTokenExchange(
199+
get_original_token=get_notebook_pat_token,
200+
host=cfg.host,
201+
scopes=cfg.scopes,
202+
authorization_details=cfg.authorization_details,
203+
)
204+
205+
def inner() -> Dict[str, str]:
206+
token = token_source.token()
207+
return {"Authorization": f"{token.token_type} {token.access_token}"}
208+
209+
def token() -> oauth.Token:
210+
return token_source.token()
211+
212+
return OAuthCredentialsProvider(inner, token)
213+
214+
179215
@oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
180216
def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
181217
"""Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
@@ -189,9 +225,10 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
189225
client_id=cfg.client_id,
190226
client_secret=cfg.client_secret,
191227
token_url=oidc.token_endpoint,
192-
scopes=["all-apis"],
228+
scopes=cfg.scopes or "all-apis",
193229
use_header=True,
194230
disable_async=cfg.disable_async_token_refresh,
231+
authorization_details=cfg.authorization_details,
195232
)
196233

197234
def inner() -> Dict[str, str]:
@@ -292,6 +329,8 @@ def token_source_for(resource: str) -> oauth.TokenSource:
292329
endpoint_params={"resource": resource},
293330
use_params=True,
294331
disable_async=cfg.disable_async_token_refresh,
332+
scopes=cfg.scopes,
333+
authorization_details=cfg.authorization_details,
295334
)
296335

297336
_ensure_host_present(cfg, token_source_for)
@@ -411,9 +450,10 @@ def token_source_for(audience: str) -> oauth.TokenSource:
411450
"subject_token": id_token,
412451
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
413452
},
414-
scopes=["all-apis"],
453+
scopes=cfg.scopes or "all-apis",
415454
use_params=True,
416455
disable_async=cfg.disable_async_token_refresh,
456+
authorization_details=cfg.authorization_details,
417457
)
418458

419459
def refreshed_headers() -> Dict[str, str]:
@@ -493,6 +533,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
493533
},
494534
use_params=True,
495535
disable_async=cfg.disable_async_token_refresh,
536+
scopes=cfg.scopes,
537+
authorization_details=cfg.authorization_details,
496538
)
497539

498540
def refreshed_headers() -> Dict[str, str]:
@@ -1070,6 +1112,7 @@ def __init__(self) -> None:
10701112
azure_devops_oidc,
10711113
external_browser,
10721114
databricks_cli,
1115+
runtime_oauth,
10731116
runtime_native_auth,
10741117
google_credentials,
10751118
google_id,

databricks/sdk/oauth.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from datetime import datetime, timedelta
1515
from enum import Enum
1616
from http.server import BaseHTTPRequestHandler, HTTPServer
17-
from typing import Any, Dict, List, Optional
17+
from typing import Any, Callable, Dict, List, Optional
1818

1919
import requests
2020
import requests.auth
@@ -32,6 +32,30 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35+
@dataclass
36+
class AuthorizationDetail:
37+
type: str
38+
object_type: str
39+
object_path: str
40+
actions: list[str]
41+
42+
def as_dict(self) -> dict:
43+
return {
44+
"type": self.type,
45+
"object_type": self.object_type,
46+
"object_path": self.object_path,
47+
"actions": self.actions,
48+
}
49+
50+
def from_dict(self, d: dict) -> "AuthorizationDetail":
51+
return AuthorizationDetail(
52+
type=d.get("type"),
53+
object_type=d.get("object_type"),
54+
object_path=d.get("object_path"),
55+
actions=d.get("actions"),
56+
)
57+
58+
3559
class IgnoreNetrcAuth(requests.auth.AuthBase):
3660
"""This auth method is a no-op.
3761
@@ -706,18 +730,21 @@ class ClientCredentials(Refreshable):
706730
client_secret: str
707731
token_url: str
708732
endpoint_params: dict = None
709-
scopes: List[str] = None
733+
scopes: str = None
710734
use_params: bool = False
711735
use_header: bool = False
712736
disable_async: bool = True
737+
authorization_details: str = None
713738

714739
def __post_init__(self):
715740
super().__init__(disable_async=self.disable_async)
716741

717742
def refresh(self) -> Token:
718743
params = {"grant_type": "client_credentials"}
719744
if self.scopes:
720-
params["scope"] = " ".join(self.scopes)
745+
params["scope"] = self.scopes
746+
if self.authorization_details:
747+
params["authorization_details"] = self.authorization_details
721748
if self.endpoint_params:
722749
for k, v in self.endpoint_params.items():
723750
params[k] = v
@@ -731,6 +758,51 @@ def refresh(self) -> Token:
731758
)
732759

733760

761+
@dataclass
762+
class PATOAuthTokenExchange(Refreshable):
763+
get_original_token: Callable[[], Optional[str]]
764+
host: str
765+
scopes: str
766+
authorization_details: str = None
767+
disable_async: bool = True
768+
769+
def __post_init__(self):
770+
super().__init__(disable_async=self.disable_async)
771+
772+
def refresh(self) -> Token:
773+
token_exchange_url = f"{self.host}/oidc/v1/token"
774+
params = {
775+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
776+
"subject_token": self.get_original_token(),
777+
"subject_token_type": "urn:databricks:params:oauth:token-type:personal-access-token",
778+
"requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
779+
"scope": self.scopes,
780+
}
781+
if self.authorization_details:
782+
params["authorization_details"] = self.authorization_details
783+
784+
resp = requests.post(token_exchange_url, params)
785+
if not resp.ok:
786+
if resp.headers["Content-Type"].startswith("application/json"):
787+
err = resp.json()
788+
code = err.get("errorCode", err.get("error", "unknown"))
789+
summary = err.get("errorSummary", err.get("error_description", "unknown"))
790+
summary = summary.replace("\r\n", " ")
791+
raise ValueError(f"{code}: {summary}")
792+
raise ValueError(resp.content)
793+
try:
794+
j = resp.json()
795+
expires_in = int(j["expires_in"])
796+
expiry = datetime.now() + timedelta(seconds=expires_in)
797+
return Token(
798+
access_token=j["access_token"],
799+
expiry=expiry,
800+
token_type=j["token_type"],
801+
)
802+
except Exception as e:
803+
raise ValueError(f"Failed to exchange PAT for OAuth token: {e}")
804+
805+
734806
class TokenCache:
735807
BASE_PATH = "~/.config/databricks-sdk-py/oauth"
736808

databricks/sdk/oidc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
202202
"subject_token": id_token.jwt,
203203
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
204204
},
205-
scopes=["all-apis"],
205+
scopes="all-apis",
206206
use_params=True,
207207
disable_async=self._disable_async,
208208
)

0 commit comments

Comments
 (0)