Skip to content

Commit 62c4c81

Browse files
committed
some tests and an example
1 parent cfefa10 commit 62c4c81

File tree

3 files changed

+48
-27
lines changed

3 files changed

+48
-27
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,25 +187,24 @@ def token() -> Token:
187187
def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
188188
if cfg.auth_type != 'external-browser':
189189
return None
190+
client_id, client_secret = None, None
190191
if cfg.client_id:
191192
client_id = cfg.client_id
192-
elif cfg.is_aws:
193+
client_secret = cfg.client_secret
194+
elif cfg.azure_client_id:
195+
client_id = cfg.azure_client
196+
client_secret = cfg.azure_client_secret
197+
198+
if not client_id:
193199
client_id = 'databricks-cli'
194-
elif cfg.is_azure:
195-
# Use Azure AD app for cases when Azure CLI is not available on the machine.
196-
# App has to be registered as Single-page multi-tenant to support PKCE
197-
# TODO: temporary app ID, change it later.
198-
client_id = '6128a518-99a9-425b-8333-4cc94f04cacd'
199-
else:
200-
raise ValueError(f'local browser SSO is not supported')
201200

202201
# Load cached credentials from disk if they exist.
203202
# Note that these are local to the Python SDK and not reused by other SDKs.
204203
oidc_endpoints = cfg.oidc_endpoints
205204
token_cache = TokenCache(host=cfg.host,
206205
oidc_endpoints=oidc_endpoints,
207206
client_id=client_id,
208-
client_secret=cfg.client_secret,
207+
client_secret=client_secret,
209208
redirect_url='http://localhost:8020')
210209
credentials = token_cache.load()
211210
if credentials:
@@ -215,7 +214,7 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
215214
oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
216215
client_id=client_id,
217216
redirect_url='http://localhost:8020',
218-
client_secret=cfg.client_secret)
217+
client_secret=client_secret)
219218
consent = oauth_client.initiate_consent()
220219
if not consent:
221220
return None

databricks/sdk/oauth.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
258258
:return: The workspace's OIDC endpoints.
259259
"""
260260
host = fix_host_if_needed(host)
261-
oidc = f'{host}/.well-known/oauth-authorization-server'
261+
oidc = f'{host}/oidc/.well-known/oauth-authorization-server'
262262
resp = client.do('GET', oidc)
263263
return OidcEndpoints.from_dict(resp)
264264

@@ -284,11 +284,11 @@ class SessionCredentials(Refreshable):
284284

285285
def __init__(self,
286286
token: Token,
287-
oidc_endpoints: OidcEndpoints,
287+
token_endpoint: str,
288288
client_id: str,
289289
client_secret: str = None,
290290
redirect_url: str = None):
291-
self._oidc_endpoints = oidc_endpoints
291+
self._token_endpoint = token_endpoint
292292
self._client_id = client_id
293293
self._client_secret = client_secret
294294
self._redirect_url = redirect_url
@@ -299,12 +299,12 @@ def as_dict(self) -> dict:
299299

300300
@staticmethod
301301
def from_dict(raw: dict,
302-
oidc_endpoints: OidcEndpoints,
302+
token_endpoint: str,
303303
client_id: str,
304304
client_secret: str = None,
305305
redirect_url: str = None) -> 'SessionCredentials':
306306
return SessionCredentials(token=Token.from_dict(raw['token']),
307-
oidc_endpoints=oidc_endpoints,
307+
token_endpoint=token_endpoint,
308308
client_id=client_id,
309309
client_secret=client_secret,
310310
redirect_url=redirect_url)
@@ -328,13 +328,13 @@ def refresh(self) -> Token:
328328
raise ValueError('oauth2: token expired and refresh token is not set')
329329
params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token}
330330
headers = {}
331-
if 'microsoft' in self._oidc_endpoints.token_endpoint:
331+
if 'microsoft' in self._token_endpoint:
332332
# Tokens issued for the 'Single-Page Application' client-type may
333333
# only be redeemed via cross-origin requests
334334
headers = {'Origin': self._redirect_url}
335335
return retrieve_token(client_id=self._client_id,
336336
client_secret=self._client_secret,
337-
token_url=self._oidc_endpoints.token_endpoint,
337+
token_url=self._token_endpoint,
338338
params=params,
339339
use_params=True,
340340
headers=headers)
@@ -345,32 +345,36 @@ class Consent:
345345
def __init__(self,
346346
state: str,
347347
verifier: str,
348-
oidc_endpoints: OidcEndpoints,
348+
authorization_url: str,
349349
redirect_url: str,
350+
token_endpoint: str,
350351
client_id: str,
351352
client_secret: str = None) -> None:
352353
self._verifier = verifier
353354
self._state = state
354-
self._oidc_endpoints = oidc_endpoints
355+
self._authorization_url = authorization_url
355356
self._redirect_url = redirect_url
357+
self._token_endpoint = token_endpoint
356358
self._client_id = client_id
357359
self._client_secret = client_secret
358360

359361
def as_dict(self) -> dict:
360362
return {
361363
'state': self._state,
362364
'verifier': self._verifier,
365+
'authorization_url': self._authorization_url,
363366
'redirect_url': self._redirect_url,
364-
'oidc_endpoints': self._oidc_endpoints.as_dict(),
367+
'token_endpoint': self._token_endpoint,
365368
'client_id': self._client_id,
366369
}
367370

368371
@staticmethod
369372
def from_dict(raw: dict, client_secret: str = None) -> 'Consent':
370373
return Consent(raw['state'],
371374
raw['verifier'],
372-
oidc_endpoints=OidcEndpoints.from_dict(raw['oidc_endpoints']),
375+
authorization_url=raw['authorization_url'],
373376
redirect_url=raw['redirect_url'],
377+
token_endpoint=raw['token_endpoint'],
374378
client_id=raw['client_id'],
375379
client_secret=client_secret)
376380

@@ -379,8 +383,8 @@ def launch_external_browser(self) -> SessionCredentials:
379383
if redirect_url.hostname not in ('localhost', '127.0.0.1'):
380384
raise ValueError(f'cannot listen on {redirect_url.hostname}')
381385
feedback = []
382-
logger.info(f'Opening {self._oidc_endpoints.authorization_endpoint} in a browser')
383-
webbrowser.open_new(self._oidc_endpoints.authorization_endpoint)
386+
logger.info(f'Opening {self._authorization_url} in a browser')
387+
webbrowser.open_new(self._authorization_url)
384388
port = redirect_url.port
385389
handler_factory = functools.partial(_OAuthCallback, feedback)
386390
with HTTPServer(("localhost", port), handler_factory) as httpd:
@@ -412,11 +416,11 @@ def exchange(self, code: str, state: str) -> SessionCredentials:
412416
try:
413417
token = retrieve_token(client_id=self._client_id,
414418
client_secret=self._client_secret,
415-
token_url=self._oidc_endpoints.token_endpoint,
419+
token_url=self._token_endpoint,
416420
params=params,
417421
headers=headers,
418422
use_params=True)
419-
return SessionCredentials(token, self._oidc_endpoints, self._client_id, self._client_secret,
423+
return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret,
420424
self._redirect_url)
421425
except ValueError as e:
422426
if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
@@ -481,11 +485,12 @@ def initiate_consent(self) -> Consent:
481485
'code_challenge': challenge,
482486
'code_challenge_method': 'S256'
483487
}
484-
f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}'
488+
auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}'
485489
return Consent(state,
486490
verifier,
487-
oidc_endpoints=self._oidc_endpoints,
491+
authorization_url=auth_url,
488492
redirect_url=self.redirect_url,
493+
token_endpoint=self._oidc_endpoints.token_endpoint,
489494
client_id=self._client_id,
490495
client_secret=self._client_secret)
491496

examples/external_browser_auth.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from databricks.sdk import WorkspaceClient
2+
import logging
3+
4+
logging.basicConfig(level=logging.DEBUG)
5+
6+
7+
def run():
8+
w = WorkspaceClient(
9+
host=input("Enter Databricks host: "),
10+
auth_type="external-browser",
11+
)
12+
me = w.current_user.me()
13+
print(me)
14+
15+
16+
if __name__ == "__main__":
17+
run()

0 commit comments

Comments
 (0)