Skip to content

Commit 384834a

Browse files
committed
rebase
1 parent 201cc13 commit 384834a

File tree

6 files changed

+318
-155
lines changed

6 files changed

+318
-155
lines changed

databricks/sdk/_base_client.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,25 @@
1818
logger = logging.getLogger('databricks.sdk')
1919

2020

21+
def fix_host_if_needed(host: Optional[str]) -> Optional[str]:
22+
if not host:
23+
return host
24+
25+
# Add a default scheme if it's missing
26+
if '://' not in host:
27+
host = 'https://' + host
28+
29+
o = urllib.parse.urlparse(host)
30+
# remove trailing slash
31+
path = o.path.rstrip('/')
32+
# remove port if 443
33+
netloc = o.netloc
34+
if o.port == 443:
35+
netloc = netloc.split(':')[0]
36+
37+
return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
38+
39+
2140
class _BaseClient:
2241

2342
def __init__(self,

databricks/sdk/config.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from .credentials_provider import CredentialsStrategy, DefaultCredentials
1515
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
1616
DatabricksEnvironment, get_environment_for_hostname)
17-
from .oauth import OidcEndpoints, Token
17+
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
18+
get_azure_entra_id_workspace_endpoints,
19+
get_workspace_endpoints)
20+
from ._base_client import fix_host_if_needed
1821

1922
logger = logging.getLogger('databricks.sdk')
2023

@@ -118,7 +121,9 @@ def __init__(self,
118121
self._set_inner_config(kwargs)
119122
self._load_from_env()
120123
self._known_file_config_loader()
121-
self._fix_host_if_needed()
124+
updated_host = fix_host_if_needed(self.host)
125+
if updated_host:
126+
self.host = updated_host
122127
self._validate()
123128
self.init_auth()
124129
self._init_product(product, product_version)
@@ -250,28 +255,14 @@ def with_user_agent_extra(self, key: str, value: str) -> 'Config':
250255

251256
@property
252257
def oidc_endpoints(self) -> Optional[OidcEndpoints]:
253-
self._fix_host_if_needed()
258+
self.host = fix_host_if_needed(self.host)
254259
if not self.host:
255260
return None
256261
if self.is_azure and self.azure_client_id:
257-
# Retrieve authorize endpoint to retrieve token endpoint after
258-
res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
259-
real_auth_url = res.headers.get('location')
260-
if not real_auth_url:
261-
return None
262-
return OidcEndpoints(authorization_endpoint=real_auth_url,
263-
token_endpoint=real_auth_url.replace('/authorize', '/token'))
262+
return get_azure_entra_id_workspace_endpoints(self.host)
264263
if self.is_account_client and self.account_id:
265-
prefix = f'{self.host}/oidc/accounts/{self.account_id}'
266-
return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize',
267-
token_endpoint=f'{prefix}/v1/token')
268-
oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server'
269-
res = requests.get(oidc)
270-
if res.status_code != 200:
271-
return None
272-
auth_metadata = res.json()
273-
return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'),
274-
token_endpoint=auth_metadata.get('token_endpoint'))
264+
return get_account_endpoints(self.host, self.account_id)
265+
return get_workspace_endpoints(self.host)
275266

276267
def debug_string(self) -> str:
277268
""" Returns log-friendly representation of configured attributes """
@@ -345,24 +336,6 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
345336
cls._attributes = attrs
346337
return cls._attributes
347338

348-
def _fix_host_if_needed(self):
349-
if not self.host:
350-
return
351-
352-
# Add a default scheme if it's missing
353-
if '://' not in self.host:
354-
self.host = 'https://' + self.host
355-
356-
o = urllib.parse.urlparse(self.host)
357-
# remove trailing slash
358-
path = o.path.rstrip('/')
359-
# remove port if 443
360-
netloc = o.netloc
361-
if o.port == 443:
362-
netloc = netloc.split(':')[0]
363-
364-
self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
365-
366339
def load_azure_tenant_id(self):
367340
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.
368341

databricks/sdk/credentials_provider.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,19 +197,24 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
197197
client_id = '6128a518-99a9-425b-8333-4cc94f04cacd'
198198
else:
199199
raise ValueError(f'local browser SSO is not supported')
200-
oauth_client = OAuthClient(host=cfg.host,
201-
client_id=client_id,
202-
redirect_url='http://localhost:8020',
203-
client_secret=cfg.client_secret)
204200

205201
# Load cached credentials from disk if they exist.
206202
# Note that these are local to the Python SDK and not reused by other SDKs.
207-
token_cache = TokenCache(oauth_client)
203+
oidc_endpoints = cfg.oidc_endpoints
204+
token_cache = TokenCache(host=cfg.host,
205+
oidc_endpoints=oidc_endpoints,
206+
client_id=client_id,
207+
client_secret=cfg.client_secret,
208+
redirect_url='http://localhost:8020')
208209
credentials = token_cache.load()
209210
if credentials:
210211
# Force a refresh in case the loaded credentials are expired.
211212
credentials.token()
212213
else:
214+
oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
215+
client_id=client_id,
216+
redirect_url='http://localhost:8020',
217+
client_secret=cfg.client_secret)
213218
consent = oauth_client.initiate_consent()
214219
if not consent:
215220
return None

0 commit comments

Comments
 (0)