|
14 | 14 | from .credentials_provider import CredentialsStrategy, DefaultCredentials |
15 | 15 | from .environments import (ALL_ENVS, AzureEnvironment, Cloud, |
16 | 16 | DatabricksEnvironment, get_environment_for_hostname) |
17 | | -from .oauth import OidcEndpoints, Token |
| 17 | +from .oauth import (OidcEndpoints, Token, get_account_endpoints, |
| 18 | + get_azure_entra_id_workspace_endpoints, |
| 19 | + get_workspace_endpoints) |
| 20 | +from ._base_client import fix_host_if_needed |
18 | 21 |
|
19 | 22 | logger = logging.getLogger('databricks.sdk') |
20 | 23 |
|
@@ -118,7 +121,9 @@ def __init__(self, |
118 | 121 | self._set_inner_config(kwargs) |
119 | 122 | self._load_from_env() |
120 | 123 | self._known_file_config_loader() |
121 | | - self._fix_host_if_needed() |
| 124 | + updated_host = fix_host_if_needed(self.host) |
| 125 | + if updated_host: |
| 126 | + self.host = updated_host |
122 | 127 | self._validate() |
123 | 128 | self.init_auth() |
124 | 129 | self._init_product(product, product_version) |
@@ -250,28 +255,14 @@ def with_user_agent_extra(self, key: str, value: str) -> 'Config': |
250 | 255 |
|
251 | 256 | @property |
252 | 257 | def oidc_endpoints(self) -> Optional[OidcEndpoints]: |
253 | | - self._fix_host_if_needed() |
| 258 | + self.host = fix_host_if_needed(self.host) |
254 | 259 | if not self.host: |
255 | 260 | return None |
256 | 261 | if self.is_azure and self.azure_client_id: |
257 | | - # Retrieve authorize endpoint to retrieve token endpoint after |
258 | | - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) |
259 | | - real_auth_url = res.headers.get('location') |
260 | | - if not real_auth_url: |
261 | | - return None |
262 | | - return OidcEndpoints(authorization_endpoint=real_auth_url, |
263 | | - token_endpoint=real_auth_url.replace('/authorize', '/token')) |
| 262 | + return get_azure_entra_id_workspace_endpoints(self.host) |
264 | 263 | if self.is_account_client and self.account_id: |
265 | | - prefix = f'{self.host}/oidc/accounts/{self.account_id}' |
266 | | - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', |
267 | | - token_endpoint=f'{prefix}/v1/token') |
268 | | - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' |
269 | | - res = requests.get(oidc) |
270 | | - if res.status_code != 200: |
271 | | - return None |
272 | | - auth_metadata = res.json() |
273 | | - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), |
274 | | - token_endpoint=auth_metadata.get('token_endpoint')) |
| 264 | + return get_account_endpoints(self.host, self.account_id) |
| 265 | + return get_workspace_endpoints(self.host) |
275 | 266 |
|
276 | 267 | def debug_string(self) -> str: |
277 | 268 | """ Returns log-friendly representation of configured attributes """ |
@@ -345,24 +336,6 @@ def attributes(cls) -> Iterable[ConfigAttribute]: |
345 | 336 | cls._attributes = attrs |
346 | 337 | return cls._attributes |
347 | 338 |
|
348 | | - def _fix_host_if_needed(self): |
349 | | - if not self.host: |
350 | | - return |
351 | | - |
352 | | - # Add a default scheme if it's missing |
353 | | - if '://' not in self.host: |
354 | | - self.host = 'https://' + self.host |
355 | | - |
356 | | - o = urllib.parse.urlparse(self.host) |
357 | | - # remove trailing slash |
358 | | - path = o.path.rstrip('/') |
359 | | - # remove port if 443 |
360 | | - netloc = o.netloc |
361 | | - if o.port == 443: |
362 | | - netloc = netloc.split(':')[0] |
363 | | - |
364 | | - self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) |
365 | | - |
366 | 339 | def load_azure_tenant_id(self): |
367 | 340 | """[Internal] Load the Azure tenant ID from the Azure Databricks login page. |
368 | 341 |
|
|
0 commit comments