Skip to content
Merged
20 changes: 20 additions & 0 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import urllib.parse
from datetime import timedelta
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Expand All @@ -17,6 +18,25 @@
logger = logging.getLogger('databricks.sdk')


def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
if not host:
return host

# Add a default scheme if it's missing
if '://' not in host:
host = 'https://' + host

o = urllib.parse.urlparse(host)
# remove trailing slash
path = o.path.rstrip('/')
# remove port if 443
netloc = o.netloc
if o.port == 443:
netloc = netloc.split(':')[0]

return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))


class _BaseClient:

def __init__(self,
Expand Down
44 changes: 10 additions & 34 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
import requests

from . import useragent
from ._base_client import _fix_host_if_needed
from .clock import Clock, RealClock
from .credentials_provider import CredentialsStrategy, DefaultCredentials
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import OidcEndpoints, Token
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
get_azure_entra_id_workspace_endpoints,
get_workspace_endpoints)

logger = logging.getLogger('databricks.sdk')

Expand Down Expand Up @@ -254,24 +257,10 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
if not self.host:
return None
if self.is_azure and self.azure_client_id:
# Retrieve authorize endpoint to retrieve token endpoint after
res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
real_auth_url = res.headers.get('location')
if not real_auth_url:
return None
return OidcEndpoints(authorization_endpoint=real_auth_url,
token_endpoint=real_auth_url.replace('/authorize', '/token'))
return get_azure_entra_id_workspace_endpoints(self.host)
if self.is_account_client and self.account_id:
prefix = f'{self.host}/oidc/accounts/{self.account_id}'
return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize',
token_endpoint=f'{prefix}/v1/token')
oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server'
res = requests.get(oidc)
if res.status_code != 200:
return None
auth_metadata = res.json()
return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'),
token_endpoint=auth_metadata.get('token_endpoint'))
return get_account_endpoints(self.host, self.account_id)
return get_workspace_endpoints(self.host)

def debug_string(self) -> str:
""" Returns log-friendly representation of configured attributes """
Expand Down Expand Up @@ -346,22 +335,9 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
return cls._attributes

def _fix_host_if_needed(self):
if not self.host:
return

# Add a default scheme if it's missing
if '://' not in self.host:
self.host = 'https://' + self.host

o = urllib.parse.urlparse(self.host)
# remove trailing slash
path = o.path.rstrip('/')
# remove port if 443
netloc = o.netloc
if o.port == 443:
netloc = netloc.split(':')[0]

self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
updated_host = _fix_host_if_needed(self.host)
if updated_host:
self.host = updated_host

def load_azure_tenant_id(self):
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.
Expand Down
31 changes: 18 additions & 13 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,30 +187,35 @@ def token() -> Token:
def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
if cfg.auth_type != 'external-browser':
return None
client_id, client_secret = None, None
if cfg.client_id:
client_id = cfg.client_id
elif cfg.is_aws:
client_secret = cfg.client_secret
elif cfg.azure_client_id:
client_id = cfg.azure_client
client_secret = cfg.azure_client_secret

if not client_id:
client_id = 'databricks-cli'
elif cfg.is_azure:
# Use Azure AD app for cases when Azure CLI is not available on the machine.
# App has to be registered as Single-page multi-tenant to support PKCE
# TODO: temporary app ID, change it later.
client_id = '6128a518-99a9-425b-8333-4cc94f04cacd'
else:
raise ValueError(f'local browser SSO is not supported')
oauth_client = OAuthClient(host=cfg.host,
client_id=client_id,
redirect_url='http://localhost:8020',
client_secret=cfg.client_secret)

# Load cached credentials from disk if they exist.
# Note that these are local to the Python SDK and not reused by other SDKs.
token_cache = TokenCache(oauth_client)
oidc_endpoints = cfg.oidc_endpoints
redirect_url = 'http://localhost:8020'
token_cache = TokenCache(host=cfg.host,
oidc_endpoints=oidc_endpoints,
client_id=client_id,
client_secret=client_secret,
redirect_url=redirect_url)
credentials = token_cache.load()
if credentials:
# Force a refresh in case the loaded credentials are expired.
credentials.token()
else:
oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
client_id=client_id,
redirect_url=redirect_url,
client_secret=client_secret)
consent = oauth_client.initiate_consent()
if not consent:
return None
Expand Down
Loading
Loading