Skip to content

Commit b6f303c

Browse files
committed
Tests
1 parent da8eb26 commit b6f303c

File tree

6 files changed

+153
-12
lines changed

6 files changed

+153
-12
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ test:
2828
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests
2929

3030
integration:
31-
pytest -n auto -m 'integration and not benchmark' --reruns 2 --dist loadgroup --cov=databricks --cov-report html tests
31+
pytest -n auto --dist loadgroup --cov=databricks --cov-report html tests/integration/test_auth.py
3232

3333
benchmark:
3434
pytest -m 'benchmark' tests

databricks/sdk/__init__.py

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class Config:
6161
host: str = ConfigAttribute(env="DATABRICKS_HOST")
6262
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
6363
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
64+
token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="databricks-wif")
6465
username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic")
6566
password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True)
6667
client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth")

databricks/sdk/credentials_provider.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .azure import add_sp_management_token, add_workspace_id_header
2424
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
2525
TokenCache, TokenSource)
26+
from .oidc_token_supplier import GitHubOIDCTokenSupplier
2627

2728
CredentialsProvider = Callable[[], Dict[str, str]]
2829

@@ -314,6 +315,44 @@ def token() -> Token:
314315
return OAuthCredentialsProvider(refreshed_headers, token)
315316

316317

318+
@oauth_credentials_strategy("databricks-wif", ["host", "client_id", "token_audience"])
319+
def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]:
320+
supplier = GitHubOIDCTokenSupplier()
321+
# Try to get a token. If no supplier returns a token, we cannot use this authentication mode.
322+
token = supplier.get_oidc_token(cfg.token_audience)
323+
if not token:
324+
return None
325+
326+
def token_source_for(audience: str) -> TokenSource:
327+
token = supplier.get_oidc_token(audience)
328+
if not token:
329+
# Should not happen, since we checked it above.
330+
raise Exception("Cannot get OIDC token")
331+
params = {
332+
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
333+
"subject_token": token,
334+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
335+
}
336+
return ClientCredentials(
337+
client_id=cfg.client_id,
338+
client_secret="", # we have no (rotatable) secrets in OIDC flow
339+
token_url=cfg.oidc_endpoints.token_endpoint,
340+
endpoint_params=params,
341+
scopes=["all-apis"],
342+
use_params=True,
343+
disable_async=not cfg.enable_experimental_async_token_refresh,
344+
)
345+
346+
def refreshed_headers() -> Dict[str, str]:
347+
token = token_source_for(cfg.token_audience).token()
348+
return {"Authorization": f"{token.token_type} {token.access_token}"}
349+
350+
def token() -> Token:
351+
return token_source_for(cfg.token_audience).token()
352+
353+
return OAuthCredentialsProvider(refreshed_headers, token)
354+
355+
317356
@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
318357
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
319358
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
@@ -325,16 +364,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
325364
if not cfg.is_azure:
326365
return None
327366

328-
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
329-
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
330-
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
331-
response = requests.get(endpoint, headers=headers)
332-
if not response.ok:
333-
return None
334-
335-
# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
336-
response_json = response.json()
337-
if "value" not in response_json:
367+
token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange")
368+
if not token:
338369
return None
339370

340371
logger.info(
@@ -344,7 +375,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
344375
params = {
345376
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
346377
"resource": cfg.effective_azure_login_app_id,
347-
"client_assertion": response_json["value"],
378+
"client_assertion": token,
348379
}
349380
aad_endpoint = cfg.arm_environment.active_directory_endpoint
350381
if not cfg.azure_tenant_id:
@@ -927,6 +958,7 @@ def __init__(self) -> None:
927958
basic_auth,
928959
metadata_service,
929960
oauth_service_principal,
961+
databricks_wif,
930962
azure_service_principal,
931963
github_oidc_azure,
932964
azure_cli,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
from time import sleep
3+
from typing import Optional
4+
5+
import requests
6+
7+
8+
class GitHubOIDCTokenSupplier:
9+
"""
10+
Supplies OIDC tokens from GitHub Actions.
11+
"""
12+
13+
def get_oidc_token(self, audience: str) -> Optional[str]:
14+
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ or "ACTIONS_ID_TOKEN_REQUEST_URL" not in os.environ:
15+
# not in GitHub actions
16+
return None
17+
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
18+
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
19+
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience={audience}"
20+
response = requests.get(endpoint, headers=headers)
21+
if not response.ok:
22+
return None
23+
24+
# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
25+
response_json = response.json()
26+
if "value" not in response_json:
27+
return None
28+
29+
# GitHub issued time is not allways in sync, and can give tokens which are not yet valid.
30+
# TODO: Remove this after Databricks API is updated to handle such cases.
31+
sleep(2)
32+
33+
return response_json["value"]

tests/integration/test_auth.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import pytest
1414

15+
from databricks.sdk import AccountClient, WorkspaceClient
16+
from databricks.sdk.service import iam, oauth2
1517
from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode,
1618
Library, ResultType, SparkVersion)
1719
from databricks.sdk.service.jobs import NotebookTask, Task, ViewType
@@ -198,3 +200,72 @@ def _task_outputs(w, run):
198200
output += data["data"]
199201
task_outputs[task_run.task_key] = output
200202
return task_outputs
203+
204+
205+
def test_wif_account(ucacct, env_or_skip, random):
206+
207+
sp = ucacct.service_principals.create(
208+
active=True,
209+
display_name="py-sdk-test-" + random(),
210+
roles=[iam.ComplexValue(value="account_admin")],
211+
)
212+
213+
ucacct.service_principal_federation_policy.create(
214+
policy=oauth2.FederationPolicy(
215+
oidc_policy=oauth2.OidcFederationPolicy(
216+
issuer="https://token.actions.githubusercontent.com",
217+
audiences=["https://github.com/databricks-eng"],
218+
subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests",
219+
)
220+
),
221+
service_principal_id=sp.id,
222+
)
223+
224+
ac = AccountClient(
225+
host=ucacct.config.host,
226+
account_id=ucacct.config.account_id,
227+
client_id=sp.application_id,
228+
auth_type="databricks-wif",
229+
token_audience="https://github.com/databricks-eng",
230+
)
231+
232+
groups = ac.groups.list()
233+
234+
next(groups)
235+
236+
237+
def test_wif_workspace(ucacct, env_or_skip, random):
238+
239+
workspace_id = env_or_skip("TEST_WORKSPACE_ID")
240+
workspace_url = env_or_skip("TEST_WORKSPACE_URL")
241+
242+
sp = ucacct.service_principals.create(
243+
active=True,
244+
display_name="py-sdk-test-" + random(),
245+
)
246+
247+
ucacct.service_principal_federation_policy.create(
248+
policy=oauth2.FederationPolicy(
249+
oidc_policy=oauth2.OidcFederationPolicy(
250+
issuer="https://token.actions.githubusercontent.com",
251+
audiences=["https://github.com/databricks-eng"],
252+
subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests",
253+
)
254+
),
255+
service_principal_id=sp.id,
256+
)
257+
258+
ucacct.workspace_assignment.update(
259+
workspace_id=workspace_id,
260+
principal_id=sp.id,
261+
permissions=[iam.WorkspacePermission.ADMIN],
262+
)
263+
264+
ws = WorkspaceClient(
265+
host=workspace_url,
266+
client_id=sp.application_id,
267+
auth_type="databricks-wif",
268+
token_audience="https://github.com/databricks-eng",
269+
)
270+
271+
ws.current_user.me()

0 commit comments

Comments
 (0)