Commit 5063be3

Tests
1 parent da8eb26 commit 5063be3

8 files changed

+174
-22
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ test:
2828
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests
2929

3030
integration:
31-
pytest -n auto -m 'integration and not benchmark' --reruns 2 --dist loadgroup --cov=databricks --cov-report html tests
31+
pytest -n auto --dist loadgroup --cov=databricks --cov-report html tests/integration/test_auth.py
3232

3333
benchmark:
3434
pytest -m 'benchmark' tests

NEXT_CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,10 @@
33
## Release v0.47.0
44

55
### New Features and Improvements
6+
* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([933](https://github.com/databricks/databricks-sdk-py/pull/933)).
7+
See README.md for instructions.
8+
* [Breaking] Users who run workflows in GitHub Actions with cloud-native authentication and who also have the `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST`
9+
environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods.
610

711
### Bug Fixes
812
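
The breaking note above comes down to ordering: in the strategy list in `credentials_provider.py`, `databricks_wif` is registered ahead of `github_oidc_azure`. The following is a toy model of that behaviour, not the SDK's implementation. `STRATEGIES` and `pick_auth` are hypothetical names, only two strategies are modelled, and the sketch assumes that an explicit `auth_type` restricts the SDK to that one method.

```python
from typing import Dict, List, Optional, Tuple

# (auth_type, required config fields), in the order the SDK tries them.
# Only two strategies are modelled; both need a GitHub Actions OIDC token.
STRATEGIES: List[Tuple[str, List[str]]] = [
    ("databricks-wif", ["host", "client_id"]),
    ("github-oidc-azure", ["host", "azure_client_id"]),
]


def pick_auth(cfg: Dict[str, str], auth_type: Optional[str] = None) -> Optional[str]:
    """Return the first strategy whose required fields are set (toy model)."""
    for name, required in STRATEGIES:
        if auth_type and name != auth_type:
            continue  # an explicit auth_type skips every other strategy
        if all(cfg.get(field) for field in required):
            return name
    return None


cfg = {
    "host": "https://<workspace-url>",
    "client_id": "<value of DATABRICKS_CLIENT_ID>",
    "azure_client_id": "<azure-application-id>",
}
print(pick_auth(cfg))                                 # databricks-wif
print(pick_auth(cfg, auth_type="github-oidc-azure"))  # github-oidc-azure
```

Under these assumptions, workflows that relied on a cloud-native method can pin `auth_type` so the new strategy is not picked first.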

README.md

Lines changed: 10 additions & 10 deletions
@@ -126,18 +126,18 @@ Depending on the Databricks authentication method, the SDK uses the following in
126126

127127
### Databricks native authentication
128128

129-
By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`auth_type="basic"` argument).
129+
By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication (`auth_type="databricks-wif"` argument).
130130

131131
- For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents.
132-
- For Databricks basic authentication, you must provide `host`, `username`, and `password` _(for AWS workspace-level operations)_; or `host`, `account_id`, `username`, and `password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents.
133-
134-
| Argument | Description | Environment variable |
135-
|--------------|-------------|-------------------|
136-
| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` |
137-
| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` |
138-
| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` |
139-
| `username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` |
140-
| `password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` |
132+
- For Databricks WIF authentication, you must provide `host` and `client_id`, and optionally `token_audience`; or their environment variable or `.databrickscfg` file field equivalents.
133+
134+
| Argument | Description | Environment variable |
135+
|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
136+
| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` |
137+
| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` |
138+
| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` |
139+
| `client_id` | _(String)_ The Databricks Service Principal Application ID. | `DATABRICKS_CLIENT_ID` |
140+
| `token_audience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. | `DATABRICKS_TOKEN_AUDIENCE` |
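
As a minimal sketch of the WIF arguments in the table above, the snippet below builds a client pinned to `auth_type="databricks-wif"`, mirroring the calls in `tests/integration/test_auth.py`. It assumes the job runs in GitHub Actions with the `id-token: write` permission and that a federation policy already exists for the service principal. The helper name `make_wif_client` and the placeholder values are illustrative only.

```python
def make_wif_client(host: str, client_id: str, token_audience: str):
    """Build a WorkspaceClient that authenticates through Workload Identity Federation."""
    # Imported lazily so this module can be loaded without the SDK installed.
    from databricks.sdk import WorkspaceClient

    return WorkspaceClient(
        host=host,
        client_id=client_id,
        auth_type="databricks-wif",  # restrict authentication to WIF
        token_audience=token_audience,
    )


# Usage inside a GitHub Actions job (placeholder values):
# w = make_wif_client(
#     host="https://<workspace-url>",
#     client_id="<service-principal-application-id>",
#     token_audience="https://github.com/<org>",
# )
# print(w.current_user.me())
```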
141141

142142
For example, to use Databricks token authentication:
143143

databricks/sdk/__init__.py

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

databricks/sdk/config.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ class Config:
6161
host: str = ConfigAttribute(env="DATABRICKS_HOST")
6262
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
6363
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
64+
token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="databricks-wif")
6465
username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic")
6566
password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True)
6667
client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth")

databricks/sdk/credentials_provider.py

Lines changed: 50 additions & 11 deletions
@@ -23,6 +23,7 @@
2323
from .azure import add_sp_management_token, add_workspace_id_header
2424
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
2525
TokenCache, TokenSource)
26+
from .oidc_token_supplier import GitHubOIDCTokenSupplier
2627

2728
CredentialsProvider = Callable[[], Dict[str, str]]
2829

@@ -314,6 +315,51 @@ def token() -> Token:
314315
return OAuthCredentialsProvider(refreshed_headers, token)
315316

316317

318+
@oauth_credentials_strategy("databricks-wif", ["host", "client_id"])
319+
def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]:
320+
"""
321+
DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges
322+
it for a Databricks Token.
323+
324+
Supported suppliers:
325+
- GitHub OIDC
326+
"""
327+
supplier = GitHubOIDCTokenSupplier()
328+
# Try to get an idToken. If no supplier returns a token, we cannot use this authentication mode.
329+
idToken = supplier.get_oidc_token(cfg.token_audience)
330+
if not idToken:
331+
return None
332+
333+
def token_source_for(audience: str) -> TokenSource:
334+
idToken = supplier.get_oidc_token(audience)
335+
if not idToken:
336+
# Should not happen, since we checked it above.
337+
raise Exception("Cannot get OIDC token")
338+
params = {
339+
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
340+
"subject_token": idToken,
341+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
342+
}
343+
return ClientCredentials(
344+
client_id=cfg.client_id,
345+
client_secret="", # we have no (rotatable) secrets in OIDC flow
346+
token_url=cfg.oidc_endpoints.token_endpoint,
347+
endpoint_params=params,
348+
scopes=["all-apis"],
349+
use_params=True,
350+
disable_async=not cfg.enable_experimental_async_token_refresh,
351+
)
352+
353+
def refreshed_headers() -> Dict[str, str]:
354+
token = token_source_for(cfg.token_audience).token()
355+
return {"Authorization": f"{token.token_type} {token.access_token}"}
356+
357+
def token() -> Token:
358+
return token_source_for(cfg.token_audience).token()
359+
360+
return OAuthCredentialsProvider(refreshed_headers, token)
361+
362+
317363
@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
318364
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
319365
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
@@ -325,16 +371,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
325371
if not cfg.is_azure:
326372
return None
327373

328-
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
329-
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
330-
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
331-
response = requests.get(endpoint, headers=headers)
332-
if not response.ok:
333-
return None
334-
335-
# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
336-
response_json = response.json()
337-
if "value" not in response_json:
374+
token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange")
375+
if not token:
338376
return None
339377

340378
logger.info(
@@ -344,7 +382,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
344382
params = {
345383
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
346384
"resource": cfg.effective_azure_login_app_id,
347-
"client_assertion": response_json["value"],
385+
"client_assertion": token,
348386
}
349387
aad_endpoint = cfg.arm_environment.active_directory_endpoint
350388
if not cfg.azure_tenant_id:
@@ -927,6 +965,7 @@ def __init__(self) -> None:
927965
basic_auth,
928966
metadata_service,
929967
oauth_service_principal,
968+
databricks_wif,
930969
azure_service_principal,
931970
github_oidc_azure,
932971
azure_cli,
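
For orientation, the `databricks_wif` strategy in this file performs an OAuth 2.0 token exchange: the GitHub-issued JWT is posted as the `subject_token` and a Databricks token comes back. The sketch below only assembles form fields that mirror the parameters shown in the diff. `token_exchange_form` is a hypothetical helper, and the exact request the SDK sends is built inside `ClientCredentials` and may differ, for instance in how the client credentials are attached.

```python
from typing import Dict
from urllib.parse import urlencode


def token_exchange_form(id_token: str, client_id: str, scope: str = "all-apis") -> Dict[str, str]:
    """Form fields for a token exchange, mirroring the parameters used by databricks_wif."""
    return {
        "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
        "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
        "subject_token": id_token,  # the OIDC token obtained from the supplier
        "client_id": client_id,     # service principal application ID
        "scope": scope,
    }


form = token_exchange_form("<github-oidc-jwt>", "<application-id>")
body = urlencode(form)  # application/x-www-form-urlencoded request body
print(body)
```

No client secret appears in the form; as the code comment in the strategy notes, the OIDC flow has no rotatable secret.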
databricks/sdk/oidc_token_supplier.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
1+
import os
2+
from time import sleep
3+
from typing import Optional
4+
5+
import requests
6+
7+
8+
class GitHubOIDCTokenSupplier:
9+
"""
10+
Supplies OIDC tokens from GitHub Actions.
11+
"""
12+
13+
def get_oidc_token(self, audience: str) -> Optional[str]:
14+
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ or "ACTIONS_ID_TOKEN_REQUEST_URL" not in os.environ:
15+
# not in GitHub actions
16+
return None
17+
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
18+
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
19+
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience={audience}"
20+
response = requests.get(endpoint, headers=headers)
21+
if not response.ok:
22+
return None
23+
24+
# Get the ID token with aud=<audience> sub=repo:org/repo:environment:name
25+
response_json = response.json()
26+
if "value" not in response_json:
27+
return None
28+
29+
# GitHub issued time is not always in sync, and can give tokens which are not yet valid.
30+
# TODO: Remove this after Databricks API is updated to handle such cases.
31+
sleep(2)
32+
33+
return response_json["value"]
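
The supplier above interpolates `audience` directly into the request URL, so an unset `token_audience` would appear to be sent as the literal string `None`, and special characters are not escaped. Below is a minimal sketch of one way to build the URL more defensively. `with_audience` is a hypothetical helper, not part of the SDK, and whether omitting the parameter yields a usable default audience is an assumption to verify against the GitHub OIDC documentation.

```python
from typing import Optional
from urllib.parse import quote


def with_audience(request_url: str, audience: Optional[str]) -> str:
    """Return request_url with a URL-encoded `audience` query parameter, if one is given."""
    if not audience:
        # No audience configured: leave the URL untouched.
        return request_url
    separator = "&" if "?" in request_url else "?"
    return f"{request_url}{separator}audience={quote(audience, safe='')}"


print(with_audience("https://example.test/token?api-version=2.0", "api://AzureADTokenExchange"))
# https://example.test/token?api-version=2.0&audience=api%3A%2F%2FAzureADTokenExchange
```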

tests/integration/test_auth.py

Lines changed: 71 additions & 0 deletions
@@ -12,6 +12,8 @@
1212

1313
import pytest
1414

15+
from databricks.sdk import AccountClient, WorkspaceClient
16+
from databricks.sdk.service import iam, oauth2
1517
from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode,
1618
Library, ResultType, SparkVersion)
1719
from databricks.sdk.service.jobs import NotebookTask, Task, ViewType
@@ -198,3 +200,72 @@ def _task_outputs(w, run):
198200
output += data["data"]
199201
task_outputs[task_run.task_key] = output
200202
return task_outputs
203+
204+
205+
def test_wif_account(ucacct, env_or_skip, random):
206+
207+
sp = ucacct.service_principals.create(
208+
active=True,
209+
display_name="py-sdk-test-" + random(),
210+
roles=[iam.ComplexValue(value="account_admin")],
211+
)
212+
213+
ucacct.service_principal_federation_policy.create(
214+
policy=oauth2.FederationPolicy(
215+
oidc_policy=oauth2.OidcFederationPolicy(
216+
issuer="https://token.actions.githubusercontent.com",
217+
audiences=["https://github.com/databricks-eng"],
218+
subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests",
219+
)
220+
),
221+
service_principal_id=sp.id,
222+
)
223+
224+
ac = AccountClient(
225+
host=ucacct.config.host,
226+
account_id=ucacct.config.account_id,
227+
client_id=sp.application_id,
228+
auth_type="databricks-wif",
229+
token_audience="https://github.com/databricks-eng",
230+
)
231+
232+
groups = ac.groups.list()
233+
234+
next(groups)
235+
236+
237+
def test_wif_workspace(ucacct, env_or_skip, random):
238+
239+
workspace_id = env_or_skip("TEST_WORKSPACE_ID")
240+
workspace_url = env_or_skip("TEST_WORKSPACE_URL")
241+
242+
sp = ucacct.service_principals.create(
243+
active=True,
244+
display_name="py-sdk-test-" + random(),
245+
)
246+
247+
ucacct.service_principal_federation_policy.create(
248+
policy=oauth2.FederationPolicy(
249+
oidc_policy=oauth2.OidcFederationPolicy(
250+
issuer="https://token.actions.githubusercontent.com",
251+
audiences=["https://github.com/databricks-eng"],
252+
subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests",
253+
)
254+
),
255+
service_principal_id=sp.id,
256+
)
257+
258+
ucacct.workspace_assignment.update(
259+
workspace_id=workspace_id,
260+
principal_id=sp.id,
261+
permissions=[iam.WorkspacePermission.ADMIN],
262+
)
263+
264+
ws = WorkspaceClient(
265+
host=workspace_url,
266+
client_id=sp.application_id,
267+
auth_type="databricks-wif",
268+
token_audience="https://github.com/databricks-eng",
269+
)
270+
271+
ws.current_user.me()
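
When a federation policy like the ones created in these tests rejects a token, it can help to inspect the token's claims. The sketch below decodes a JWT payload without verifying its signature (for debugging only) and compares `iss`, `aud`, and `sub` with the policy fields. Both helper names are hypothetical, and `matches_policy` is a simplified model: the real matching is performed server-side by Databricks and may differ.

```python
import base64
import json
from typing import Any, Dict, List


def unverified_claims(jwt: str) -> Dict[str, Any]:
    """Decode the payload of a JWT without verifying its signature (debugging only)."""
    payload = jwt.split(".")[1]
    payload += "=" * (-len(payload) % 4)  # restore the base64 padding that JWTs strip
    return json.loads(base64.urlsafe_b64decode(payload))


def matches_policy(claims: Dict[str, Any], issuer: str, audiences: List[str], subject: str) -> bool:
    """Simplified check of token claims against an OIDC federation policy."""
    aud = claims.get("aud")
    # The `aud` claim may be a single string or a list of strings.
    token_audiences = [aud] if isinstance(aud, str) else list(aud or [])
    return (
        claims.get("iss") == issuer
        and claims.get("sub") == subject
        and any(a in audiences for a in token_audiences)
    )
```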
