Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features and Improvements

* Add native support for authentication through Azure DevOps OIDC

### Bug Fixes

### Documentation
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,11 @@ Depending on the Databricks authentication method, the SDK uses the following in

### Databricks native authentication

By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication using OIDC (`auth_type="github-oidc"` argument).
By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Workload Identity Federation (WIF). See [Supported WIF](https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation-provider) for the supported JWT token providers.

- For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents.
- For Databricks OIDC authentication, you must provide the `host`, `client_id` and `token_audience` _(optional)_ either directly, through the corresponding environment variables, or in your `.databrickscfg` configuration file.
- For Azure DevOps OIDC authentication, the `token_audience` is irrelevant as the audience is always set to `api://AzureADTokenExchange`. Also, the `System.AccessToken` pipeline variable required for OIDC request must be exposed as the `SYSTEM_ACCESSTOKEN` environment variable, following [Pipeline variables](https://learn.microsoft.com/en-us/azure/devops/pipelines/build/variables?view=azure-devops&tabs=yaml#systemaccesstoken)

| Argument | Description | Environment variable |
|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
Expand Down
73 changes: 61 additions & 12 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import threading
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import google.auth # type: ignore
import requests
Expand Down Expand Up @@ -89,7 +89,6 @@ def inner(
@functools.wraps(func)
def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
for attr in require:
getattr(cfg, attr)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seemed like unnecessary Double check

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure? what does this function do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function checks if the attribute exists in the config or not. This check is redundant, as it is done in the next line as well

if not getattr(cfg, attr):
return None
return func(cfg)
Expand All @@ -103,7 +102,12 @@ def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
def oauth_credentials_strategy(name: str, require: List[str]):
"""Given the function that receives a Config and returns an OauthHeaderFactory,
create an OauthCredentialsProvider with a given name and required configuration
attribute names to be present for this function to be called."""
attribute names to be present for this function to be called.

Args:
name: The name of the authentication strategy
require: List of config attributes that must be present
"""

def inner(
func: Callable[["Config"], OAuthCredentialsProvider],
Expand Down Expand Up @@ -356,33 +360,47 @@ def token() -> oauth.Token:
return OAuthCredentialsProvider(refreshed_headers, token)


@oauth_credentials_strategy("github-oidc", ["host", "client_id"])
def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]:
def _oidc_credentials_provider(
cfg: "Config", supplier_factory: Callable[[], Any], provider_name: str
) -> Optional[CredentialsProvider]:
"""
DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges
it for a Databricks Token.
Generic OIDC credentials provider that works with any OIDC token supplier.

Args:
cfg: Databricks configuration
supplier_factory: Callable that returns an OIDC token supplier instance
provider_name: Human-readable name (e.g., "GitHub OIDC", "Azure DevOps OIDC")

Supported suppliers:
- GitHub OIDC
Returns:
OAuthCredentialsProvider if successful, None if supplier unavailable or token retrieval fails
"""
supplier = oidc_token_supplier.GitHubOIDCTokenSupplier()
# Try to create the supplier
try:
supplier = supplier_factory()
except Exception as e:
logger.debug(f"{provider_name}: {str(e)}")
return None

# Determine the audience for token exchange
audience = cfg.token_audience
if audience is None and cfg.is_account_client:
audience = cfg.account_id
if audience is None and not cfg.is_account_client:
audience = cfg.oidc_endpoints.token_endpoint

# Try to get an idToken. If no supplier returns a token, we cannot use this authentication mode.
# Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode.
id_token = supplier.get_oidc_token(audience)
if not id_token:
logger.debug(f"{provider_name}: no token available, skipping authentication method")
return None

logger.info(f"Configured {provider_name} authentication")

def token_source_for(audience: str) -> oauth.TokenSource:
id_token = supplier.get_oidc_token(audience)
if not id_token:
# Should not happen, since we checked it above.
raise Exception("Cannot get OIDC token")
raise Exception(f"Cannot get {provider_name} token")

return oauth.ClientCredentials(
client_id=cfg.client_id,
Expand All @@ -408,6 +426,36 @@ def token() -> oauth.Token:
return OAuthCredentialsProvider(refreshed_headers, token)


@oauth_credentials_strategy("github-oidc", ["host", "client_id"])
def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]:
"""
GitHub OIDC authentication uses a Token Supplier to get a JWT Token and exchanges
it for a Databricks Token.

Supported in GitHub Actions with OIDC service connections.
"""
return _oidc_credentials_provider(
cfg=cfg,
supplier_factory=lambda: oidc_token_supplier.GitHubOIDCTokenSupplier(),
provider_name="GitHub OIDC",
)


@oauth_credentials_strategy("azure-devops-oidc", ["host", "client_id"])
def azure_devops_oidc(cfg: "Config") -> Optional[CredentialsProvider]:
"""
Azure DevOps OIDC authentication uses a Token Supplier to get a JWT Token
and exchanges it for a Databricks Token.

Supported in Azure DevOps pipelines with OIDC service connections.
"""
return _oidc_credentials_provider(
cfg=cfg,
supplier_factory=lambda: oidc_token_supplier.AzureDevOpsOIDCTokenSupplier(),
provider_name="Azure DevOps OIDC",
)


@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
Expand Down Expand Up @@ -1019,6 +1067,7 @@ def __init__(self) -> None:
azure_service_principal,
github_oidc_azure,
azure_cli,
azure_devops_oidc,
external_browser,
databricks_cli,
runtime_native_auth,
Expand Down
80 changes: 80 additions & 0 deletions databricks/sdk/oidc_token_supplier.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
import os
from typing import Optional

import requests

logger = logging.getLogger("databricks.sdk")


# TODO: Check the required environment variables while creating the instance rather than in the get_oidc_token method to allow early return.
class GitHubOIDCTokenSupplier:
"""
Supplies OIDC tokens from GitHub Actions.
Expand All @@ -26,3 +30,79 @@ def get_oidc_token(self, audience: str) -> Optional[str]:
return None

return response_json["value"]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same: GitHubOIDC does not validate on create. Is there a reasons to change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done to ensure Early exit. This is similar to what is done in Go SDK. If the environment variables are not set then we are sure that we are not in Azure DevOps Environment so we should exit at the earliest and try other providers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a TODO on the GitHub OIDC implementation to make it clear that it is the one with tech debt?


class AzureDevOpsOIDCTokenSupplier:
"""
Supplies OIDC tokens from Azure DevOps pipelines.

Constructs the OIDC token request URL using official Azure DevOps predefined variables.
See: https://docs.microsoft.com/en-us/azure/devops/pipelines/build/variables
"""

def __init__(self):
"""Initialize and validate Azure DevOps environment variables."""
# Get Azure DevOps environment variables.
self.access_token = os.environ.get("SYSTEM_ACCESSTOKEN")
self.collection_uri = os.environ.get("SYSTEM_TEAMFOUNDATIONCOLLECTIONURI")
self.project_id = os.environ.get("SYSTEM_TEAMPROJECTID")
self.plan_id = os.environ.get("SYSTEM_PLANID")
self.job_id = os.environ.get("SYSTEM_JOBID")
self.hub_name = os.environ.get("SYSTEM_HOSTTYPE")

# Check for required variables with specific error messages.
missing_vars = []
if not self.access_token:
missing_vars.append("SYSTEM_ACCESSTOKEN")
if not self.collection_uri:
missing_vars.append("SYSTEM_TEAMFOUNDATIONCOLLECTIONURI")
if not self.project_id:
missing_vars.append("SYSTEM_TEAMPROJECTID")
if not self.plan_id:
missing_vars.append("SYSTEM_PLANID")
if not self.job_id:
missing_vars.append("SYSTEM_JOBID")
if not self.hub_name:
missing_vars.append("SYSTEM_HOSTTYPE")

if missing_vars:
if "SYSTEM_ACCESSTOKEN" in missing_vars:
error_msg = "Azure DevOps OIDC: SYSTEM_ACCESSTOKEN env var not found. If calling from Azure DevOps Pipeline, please set this env var following https://learn.microsoft.com/en-us/azure/devops/pipelines/build/variables?view=azure-devops&tabs=yaml#systemaccesstoken"
else:
error_msg = f"Azure DevOps OIDC: missing required environment variables: {', '.join(missing_vars)}"
raise ValueError(error_msg)

def get_oidc_token(self, audience: str) -> Optional[str]:
# Note: Azure DevOps OIDC tokens have a fixed audience of "api://AzureADTokenExchange".
# The audience parameter is ignored but kept for interface compatibility with other OIDC suppliers.

try:
# Construct the OIDC token request URL.
# Format: {collection_uri}{project_id}/_apis/distributedtask/hubs/{hubName}/plans/{planId}/jobs/{jobId}/oidctoken.
request_url = f"{self.collection_uri}{self.project_id}/_apis/distributedtask/hubs/{self.hub_name}/plans/{self.plan_id}/jobs/{self.job_id}/oidctoken"

# Add API version (audience is fixed to "api://AzureADTokenExchange" by Azure DevOps).
endpoint = f"{request_url}?api-version=7.2-preview.1"
headers = {
"Authorization": f"Bearer {self.access_token}",
"Content-Type": "application/json",
"Content-Length": "0",
}

# Azure DevOps OIDC endpoint requires POST request with empty body.
response = requests.post(endpoint, headers=headers)
if not response.ok:
logger.debug(f"Azure DevOps OIDC: token request failed with status {response.status_code}")
return None

# Azure DevOps returns the token in 'oidcToken' field.
response_json = response.json()
if "oidcToken" not in response_json:
logger.debug("Azure DevOps OIDC: response missing 'oidcToken' field")
return None

logger.debug("Azure DevOps OIDC: successfully obtained token")
return response_json["oidcToken"]
except Exception as e:
logger.debug(f"Azure DevOps OIDC: failed to get token: {e}")
return None
Loading
Loading