diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index adc7e3612..650fd402d 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,10 @@ * Enabled asynchronous token refreshes by default. A new `disable_async_token_refresh` configuration option has been added to allow disabling this feature if necessary ([#952](https://github.com/databricks/databricks-sdk-py/pull/952)). To disable asynchronous token refresh, set the environment variable `DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH=true` or configure it within your configuration object. The previous `enable_experimental_async_token_refresh` option has been removed as asynchronous refresh is now the default behavior. +* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([#933](https://github.com/databricks/databricks-sdk-py/pull/933)). + See README.md for instructions. +* [Breaking] Users running their workflows in GitHub Actions, which use Cloud native authentication and also have the `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST` + environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods. ### Bug Fixes diff --git a/README.md b/README.md index 9991c9cd0..58a885307 100644 --- a/README.md +++ b/README.md @@ -126,18 +126,18 @@ Depending on the Databricks authentication method, the SDK uses the following in ### Databricks native authentication -By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`auth_type="basic"` argument). +By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). 
If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication using OIDC (`auth_type="github-oidc"` argument). - For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents. -- For Databricks basic authentication, you must provide `host`, `username`, and `password` _(for AWS workspace-level operations)_; or `host`, `account_id`, `username`, and `password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents. - -| Argument | Description | Environment variable | -|--------------|-------------|-------------------| -| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` | -| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` | -| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` | -| `username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` | -| `password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` | +- For Databricks OIDC authentication, you must provide the `host`, `client_id` and `token_audience` _(optional)_ either directly, through the corresponding environment variables, or in your `.databrickscfg` configuration file. 
+ +| Argument | Description | Environment variable | +|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| +| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` | +| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` | +| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` | +| `client_id` | _(String)_ The Databricks Service Principal Application ID. | `DATABRICKS_CLIENT_ID` | +| `token_audience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. 
| `DATABRICKS_TOKEN_AUDIENCE` | For example, to use Databricks token authentication: diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 5a4f9d75d..9189284dd 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -170,6 +170,7 @@ def __init__( product_version="0.0.0", credentials_strategy: Optional[CredentialsStrategy] = None, credentials_provider: Optional[CredentialsStrategy] = None, + token_audience: Optional[str] = None, config: Optional[client.Config] = None, ): if not config: @@ -198,6 +199,7 @@ def __init__( debug_headers=debug_headers, product=product, product_version=product_version, + token_audience=token_audience, ) self._config = config.copy() self._dbutils = _make_dbutils(self._config) @@ -862,6 +864,7 @@ def __init__( product_version="0.0.0", credentials_strategy: Optional[CredentialsStrategy] = None, credentials_provider: Optional[CredentialsStrategy] = None, + token_audience: Optional[str] = None, config: Optional[client.Config] = None, ): if not config: @@ -890,6 +893,7 @@ def __init__( debug_headers=debug_headers, product=product, product_version=product_version, + token_audience=token_audience, ) self._config = config.copy() self._api_client = client.ApiClient(self._config) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 7591e6896..1e674806f 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -61,6 +61,7 @@ class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) + token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc") username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") diff --git 
a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index caf4c45f0..86bd5c4d2 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -23,6 +23,7 @@ from .azure import add_sp_management_token, add_workspace_id_header from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) +from .oidc_token_supplier import GitHubOIDCTokenSupplier CredentialsProvider = Callable[[], Dict[str, str]] @@ -314,6 +315,58 @@ def token() -> Token: return OAuthCredentialsProvider(refreshed_headers, token) +@oauth_credentials_strategy("github-oidc", ["host", "client_id"]) +def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: + """ + Workload Identity Federation (WIF): uses a Token Supplier to get a JWT Token and exchanges + it for a Databricks Token. Returns None when the supplier yields no token. + + Supported suppliers: + - GitHub OIDC + """ + supplier = GitHubOIDCTokenSupplier() + + audience = cfg.token_audience + if audience is None and cfg.is_account_client: + audience = cfg.account_id + if audience is None and not cfg.is_account_client: + audience = cfg.oidc_endpoints.token_endpoint + + # Try to get an idToken. If no supplier returns a token, we cannot use this authentication mode. + id_token = supplier.get_oidc_token(audience) + if not id_token: + return None + + def token_source_for(audience: str) -> TokenSource: + id_token = supplier.get_oidc_token(audience) + if not id_token: + # Unexpected, since the same lookup succeeded above; note the supplier is queried again on every call. 
+ raise Exception("Cannot get OIDC token") + params = { + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": id_token, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + } + return ClientCredentials( + client_id=cfg.client_id, + client_secret="", # we have no (rotatable) secrets in OIDC flow + token_url=cfg.oidc_endpoints.token_endpoint, + endpoint_params=params, + scopes=["all-apis"], + use_params=True, + disable_async=cfg.disable_async_token_refresh, + ) + + def refreshed_headers() -> Dict[str, str]: + token = token_source_for(audience).token() + return {"Authorization": f"{token.token_type} {token.access_token}"} + + def token() -> Token: + return token_source_for(audience).token() + + return OAuthCredentialsProvider(refreshed_headers, token) + + @oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"]) def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ: @@ -325,16 +378,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: if not cfg.is_azure: return None - # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers - headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"} - endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange" - response = requests.get(endpoint, headers=headers) - if not response.ok: - return None - - # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name - response_json = response.json() - if "value" not in response_json: + token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange") + if not token: return None logger.info( @@ -344,7 +389,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: params = { "client_assertion_type": 
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer", "resource": cfg.effective_azure_login_app_id, - "client_assertion": response_json["value"], + "client_assertion": token, } aad_endpoint = cfg.arm_environment.active_directory_endpoint if not cfg.azure_tenant_id: @@ -927,6 +972,7 @@ def __init__(self) -> None: basic_auth, metadata_service, oauth_service_principal, + databricks_wif, azure_service_principal, github_oidc_azure, azure_cli, diff --git a/databricks/sdk/oidc_token_supplier.py b/databricks/sdk/oidc_token_supplier.py new file mode 100644 index 000000000..dfd139de5 --- /dev/null +++ b/databricks/sdk/oidc_token_supplier.py @@ -0,0 +1,28 @@ +import os +from typing import Optional + +import requests + + +class GitHubOIDCTokenSupplier: + """ + Supplies OIDC tokens from GitHub Actions. + """ + + def get_oidc_token(self, audience: str) -> Optional[str]: + if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ or "ACTIONS_ID_TOKEN_REQUEST_URL" not in os.environ: + # not in GitHub actions + return None + # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers + headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"} + endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience={audience}" + response = requests.get(endpoint, headers=headers) + if not response.ok: + return None + + # get the ID Token with aud=<audience> sub=repo:org/repo:environment:name + response_json = response.json() + if "value" not in response_json: + return None + + return response_json["value"] diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index b50c54f1b..14aea59bf 100644 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -12,6 +12,8 @@ import pytest +from databricks.sdk import AccountClient, WorkspaceClient +from databricks.sdk.service import iam, oauth2 from databricks.sdk.service.compute import 
(ClusterSpec, DataSecurityMode, Library, ResultType, SparkVersion) from databricks.sdk.service.jobs import NotebookTask, Task, ViewType @@ -198,3 +200,72 @@ def _task_outputs(w, run): output += data["data"] task_outputs[task_run.task_key] = output return task_outputs + + +def test_wif_account(ucacct, env_or_skip, random): + + sp = ucacct.service_principals.create( + active=True, + display_name="py-sdk-test-" + random(), + roles=[iam.ComplexValue(value="account_admin")], + ) + + ucacct.service_principal_federation_policy.create( + policy=oauth2.FederationPolicy( + oidc_policy=oauth2.OidcFederationPolicy( + issuer="https://token.actions.githubusercontent.com", + audiences=["https://github.com/databricks-eng"], + subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests", + ) + ), + service_principal_id=sp.id, + ) + + ac = AccountClient( + host=ucacct.config.host, + account_id=ucacct.config.account_id, + client_id=sp.application_id, + auth_type="github-oidc", + token_audience="https://github.com/databricks-eng", + ) + + groups = ac.groups.list() + + next(groups) + + +def test_wif_workspace(ucacct, env_or_skip, random): + + workspace_id = env_or_skip("TEST_WORKSPACE_ID") + workspace_url = env_or_skip("TEST_WORKSPACE_URL") + + sp = ucacct.service_principals.create( + active=True, + display_name="py-sdk-test-" + random(), + ) + + ucacct.service_principal_federation_policy.create( + policy=oauth2.FederationPolicy( + oidc_policy=oauth2.OidcFederationPolicy( + issuer="https://token.actions.githubusercontent.com", + audiences=["https://github.com/databricks-eng"], + subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests", + ) + ), + service_principal_id=sp.id, + ) + + ucacct.workspace_assignment.update( + workspace_id=workspace_id, + principal_id=sp.id, + permissions=[iam.WorkspacePermission.ADMIN], + ) + + ws = WorkspaceClient( + host=workspace_url, + client_id=sp.application_id, + auth_type="github-oidc", + 
token_audience="https://github.com/databricks-eng", + ) + + ws.current_user.me()