Skip to content

Commit db1f4ae

Browse files
Create a method to generate OAuth tokens (#644)
## Changes Add method to get OAuth tokens ## Tests - [X] `make test` run locally - [X] `make fmt` applied - [ ] relevant integration tests applied - [X] Manual test (cannot be run as integration tests due to limitations in the current infrastructure setup) ``` def test(): w = WorkspaceClient(profile='DEFAULT') auth_details = f'"type":"workspace_permission","object_type":"serving-endpoints","object_path":"/serving-endpoints/REDACTED","actions":["query_inference_endpoint"]' auth_details = "[{" + auth_details + "}]" t = w.api_client.get_oauth_token(auth_details) print(t) ``` Result: ``` Token(access_token='REDACTED', token_type='Bearer', refresh_token=None, expiry=datetime.datetime(2024, 5, 16, 11, 9, 8, 221008)) ```
1 parent b13042b commit db1f4ae

File tree

14 files changed

+202
-82
lines changed

14 files changed

+202
-82
lines changed

.codegen/__init__.py.tmpl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import databricks.sdk.core as client
22
import databricks.sdk.dbutils as dbutils
3-
from databricks.sdk.credentials_provider import CredentialsProvider
3+
from databricks.sdk.credentials_provider import CredentialsStrategy
44

55
from databricks.sdk.mixins.files import DbfsExt
66
from databricks.sdk.mixins.compute import ClustersExt
@@ -46,10 +46,12 @@ class WorkspaceClient:
4646
debug_headers: bool = None,
4747
product="unknown",
4848
product_version="0.0.0",
49-
credentials_provider: CredentialsProvider = None,
49+
credentials_strategy: CredentialsStrategy = None,
50+
credentials_provider: CredentialsStrategy = None,
5051
config: client.Config = None):
5152
if not config:
5253
config = client.Config({{range $args}}{{.}}={{.}}, {{end}}
54+
credentials_strategy=credentials_strategy,
5355
credentials_provider=credentials_provider,
5456
debug_truncate_bytes=debug_truncate_bytes,
5557
debug_headers=debug_headers,
@@ -101,10 +103,12 @@ class AccountClient:
101103
debug_headers: bool = None,
102104
product="unknown",
103105
product_version="0.0.0",
104-
credentials_provider: CredentialsProvider = None,
106+
credentials_strategy: CredentialsStrategy = None,
107+
credentials_provider: CredentialsStrategy = None,
105108
config: client.Config = None):
106109
if not config:
107110
config = client.Config({{range $args}}{{.}}={{.}}, {{end}}
111+
credentials_strategy=credentials_strategy,
108112
credentials_provider=credentials_provider,
109113
debug_truncate_bytes=debug_truncate_bytes,
110114
debug_headers=debug_headers,

databricks/sdk/__init__.py

Lines changed: 7 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import requests
1212

1313
from .clock import Clock, RealClock
14-
from .credentials_provider import CredentialsProvider, DefaultCredentials
14+
from .credentials_provider import CredentialsStrategy, DefaultCredentials
1515
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
1616
DatabricksEnvironment, get_environment_for_hostname)
17-
from .oauth import OidcEndpoints
17+
from .oauth import OidcEndpoints, Token
1818
from .version import __version__
1919

2020
logger = logging.getLogger('databricks.sdk')
@@ -81,15 +81,25 @@ class Config:
8181

8282
def __init__(self,
8383
*,
84-
credentials_provider: CredentialsProvider = None,
84+
# Deprecated. Use credentials_strategy instead.
85+
credentials_provider: CredentialsStrategy = None,
86+
credentials_strategy: CredentialsStrategy = None,
8587
product="unknown",
8688
product_version="0.0.0",
8789
clock: Clock = None,
8890
**kwargs):
8991
self._header_factory = None
9092
self._inner = {}
9193
self._user_agent_other_info = []
92-
self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials()
94+
if credentials_strategy and credentials_provider:
95+
raise ValueError(
96+
"When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
97+
if credentials_provider:
98+
logger.warning(
99+
"parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead.")
100+
self._credentials_strategy = next(
101+
s for s in [credentials_strategy, credentials_provider,
102+
DefaultCredentials()] if s is not None)
93103
if 'databricks_environment' in kwargs:
94104
self.databricks_environment = kwargs['databricks_environment']
95105
del kwargs['databricks_environment']
@@ -107,6 +117,9 @@ def __init__(self,
107117
message = self.wrap_debug_info(str(e))
108118
raise ValueError(message) from e
109119

120+
def oauth_token(self) -> Token:
121+
return self._credentials_strategy.oauth_token(self)
122+
110123
def wrap_debug_info(self, message: str) -> str:
111124
debug_string = self.debug_string()
112125
if debug_string:
@@ -436,12 +449,12 @@ def _validate(self):
436449

437450
def init_auth(self):
438451
try:
439-
self._header_factory = self._credentials_provider(self)
440-
self.auth_type = self._credentials_provider.auth_type()
452+
self._header_factory = self._credentials_strategy(self)
453+
self.auth_type = self._credentials_strategy.auth_type()
441454
if not self._header_factory:
442455
raise ValueError('not configured')
443456
except ValueError as e:
444-
raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e
457+
raise ValueError(f'{self._credentials_strategy.auth_type()} auth: {e}') from e
445458

446459
def __repr__(self):
447460
return f'<{self.debug_string()}>'

databricks/sdk/core.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from json import JSONDecodeError
55
from types import TracebackType
66
from typing import Any, BinaryIO, Iterator, Type
7+
from urllib.parse import urlencode
78

89
from requests.adapters import HTTPAdapter
910

@@ -13,12 +14,17 @@
1314
from .credentials_provider import *
1415
from .errors import DatabricksError, error_mapper
1516
from .errors.private_link import _is_private_link_redirect
17+
from .oauth import retrieve_token
1618
from .retries import retried
1719

1820
__all__ = ['Config', 'DatabricksError']
1921

2022
logger = logging.getLogger('databricks.sdk')
2123

24+
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
25+
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
26+
OIDC_TOKEN_PATH = "/oidc/v1/token"
27+
2228

2329
class ApiClient:
2430
_cfg: Config
@@ -109,6 +115,22 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
109115
flattened = dict(flatten_dict(with_fixed_bools))
110116
return flattened
111117

118+
def get_oauth_token(self, auth_details: str) -> Token:
119+
if not self._cfg.auth_type:
120+
self._cfg.authenticate()
121+
original_token = self._cfg.oauth_token()
122+
headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
123+
params = urlencode({
124+
"grant_type": JWT_BEARER_GRANT_TYPE,
125+
"authorization_details": auth_details,
126+
"assertion": original_token.access_token
127+
})
128+
return retrieve_token(client_id=self._cfg.client_id,
129+
client_secret=self._cfg.client_secret,
130+
token_url=self._cfg.host + OIDC_TOKEN_PATH,
131+
params=params,
132+
headers=headers)
133+
112134
def do(self,
113135
method: str,
114136
path: str,

0 commit comments

Comments
 (0)