Skip to content

Commit 3393d78

Browse files
committed
restructured code and tests
1 parent 3b5ab96 commit 3393d78

File tree

3 files changed

+188
-121
lines changed

3 files changed

+188
-121
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -99,37 +99,21 @@ def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
9999
return inner
100100

101101

102-
def oauth_credentials_strategy(name: str, require: List[str], env_vars: Optional[List[str]] = None):
102+
def oauth_credentials_strategy(name: str, require: List[str]):
103103
"""Given the function that receives a Config and returns an OauthHeaderFactory,
104104
create an OauthCredentialsProvider with a given name and required configuration
105105
attribute names to be present for this function to be called.
106106
107107
Args:
108108
name: The name of the authentication strategy
109109
require: List of config attributes that must be present
110-
env_vars: Optional list of environment variables that must all be present for this strategy
111110
"""
112111

113112
def inner(
114113
func: Callable[["Config"], OAuthCredentialsProvider],
115114
) -> OauthCredentialsStrategy:
116115
@functools.wraps(func)
117116
def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
118-
# Early environment detection - check before config validation
119-
if env_vars and not all(os.environ.get(var) for var in env_vars):
120-
# Provide specific error message for Azure DevOps OIDC SYSTEM_ACCESSTOKEN
121-
if (
122-
name == "azdo-oidc"
123-
and "SYSTEM_ACCESSTOKEN" in env_vars
124-
and not os.environ.get("SYSTEM_ACCESSTOKEN")
125-
):
126-
logger.debug(
127-
"Azure DevOps OIDC: SYSTEM_ACCESSTOKEN env var not found. If calling from Azure DevOps Pipeline, please set this env var following https://learn.microsoft.com/en-us/azure/devops/pipelines/build/variables?view=azure-devops&tabs=yaml#systemaccesstoken"
128-
)
129-
else:
130-
logger.debug(f"{name}: required environment variables not present, skipping")
131-
return None
132-
133117
for attr in require:
134118
if not getattr(cfg, attr):
135119
return None
@@ -428,26 +412,19 @@ def token() -> oauth.Token:
428412
return OAuthCredentialsProvider(refreshed_headers, token)
429413

430414

431-
@oauth_credentials_strategy(
432-
"azdo-oidc",
433-
["host", "client_id"],
434-
env_vars=[
435-
"SYSTEM_ACCESSTOKEN",
436-
"SYSTEM_TEAMFOUNDATIONCOLLECTIONURI",
437-
"SYSTEM_TEAMPROJECTID",
438-
"SYSTEM_PLANID",
439-
"SYSTEM_JOBID",
440-
"SYSTEM_HOSTTYPE",
441-
],
442-
)
415+
@oauth_credentials_strategy("azure-devops-oidc", ["host", "client_id"])
443416
def azure_devops_oidc(cfg: "Config") -> Optional[CredentialsProvider]:
444417
"""
445418
Azure DevOps OIDC authentication uses a Token Supplier to get a JWT Token
446419
and exchanges it for a Databricks Token.
447420
448421
Supported in Azure DevOps pipelines with OIDC service connections.
449422
"""
450-
supplier = oidc_token_supplier.AzureDevOpsOIDCTokenSupplier()
423+
try:
424+
supplier = oidc_token_supplier.AzureDevOpsOIDCTokenSupplier()
425+
except ValueError as e:
426+
logger.debug(str(e))
427+
return None
451428

452429
audience = cfg.token_audience
453430
if audience is None and cfg.is_account_client:

databricks/sdk/oidc_token_supplier.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,43 +39,62 @@ class AzureDevOpsOIDCTokenSupplier:
3939
See: https://docs.microsoft.com/en-us/azure/devops/pipelines/build/variables
4040
"""
4141

42+
def __init__(self):
43+
"""Initialize and validate Azure DevOps environment variables."""
44+
# Get Azure DevOps environment variables.
45+
self.access_token = os.environ.get("SYSTEM_ACCESSTOKEN")
46+
self.collection_uri = os.environ.get("SYSTEM_TEAMFOUNDATIONCOLLECTIONURI")
47+
self.project_id = os.environ.get("SYSTEM_TEAMPROJECTID")
48+
self.plan_id = os.environ.get("SYSTEM_PLANID")
49+
self.job_id = os.environ.get("SYSTEM_JOBID")
50+
self.hub_name = os.environ.get("SYSTEM_HOSTTYPE")
51+
52+
# Check for required variables with specific error messages.
53+
missing_vars = []
54+
if not self.access_token:
55+
missing_vars.append("SYSTEM_ACCESSTOKEN")
56+
if not self.collection_uri:
57+
missing_vars.append("SYSTEM_TEAMFOUNDATIONCOLLECTIONURI")
58+
if not self.project_id:
59+
missing_vars.append("SYSTEM_TEAMPROJECTID")
60+
if not self.plan_id:
61+
missing_vars.append("SYSTEM_PLANID")
62+
if not self.job_id:
63+
missing_vars.append("SYSTEM_JOBID")
64+
if not self.hub_name:
65+
missing_vars.append("SYSTEM_HOSTTYPE")
66+
67+
if missing_vars:
68+
if "SYSTEM_ACCESSTOKEN" in missing_vars:
69+
error_msg = "Azure DevOps OIDC: SYSTEM_ACCESSTOKEN env var not found. If calling from Azure DevOps Pipeline, please set this env var following https://learn.microsoft.com/en-us/azure/devops/pipelines/build/variables?view=azure-devops&tabs=yaml#systemaccesstoken"
70+
else:
71+
error_msg = f"Azure DevOps OIDC: missing required environment variables: {', '.join(missing_vars)}"
72+
raise ValueError(error_msg)
73+
4274
def get_oidc_token(self, audience: str) -> Optional[str]:
43-
# Note: Azure DevOps OIDC tokens have a fixed audience of "api://AzureADTokenExchange"
44-
# The audience parameter is ignored but kept for interface compatibility with other OIDC suppliers
45-
46-
access_token = os.environ.get("SYSTEM_ACCESSTOKEN")
47-
collection_uri = os.environ.get("SYSTEM_TEAMFOUNDATIONCOLLECTIONURI")
48-
project_id = os.environ.get("SYSTEM_TEAMPROJECTID")
49-
plan_id = os.environ.get("SYSTEM_PLANID")
50-
job_id = os.environ.get("SYSTEM_JOBID")
51-
hub_name = os.environ.get("SYSTEM_HOSTTYPE")
52-
53-
# Check for required variables
54-
if not all([access_token, collection_uri, project_id, plan_id, job_id, hub_name]):
55-
# not in Azure DevOps pipeline
56-
logger.debug("Azure DevOps OIDC: not in Azure DevOps pipeline environment")
57-
return None
75+
# Note: Azure DevOps OIDC tokens have a fixed audience of "api://AzureADTokenExchange".
76+
# The audience parameter is ignored but kept for interface compatibility with other OIDC suppliers.
5877

5978
try:
60-
# Construct the OIDC token request URL
61-
# Format: {collection_uri}{project_id}/_apis/distributedtask/hubs/{hubName}/plans/{planId}/jobs/{jobId}/oidctoken
62-
request_url = f"{collection_uri}{project_id}/_apis/distributedtask/hubs/{hub_name}/plans/{plan_id}/jobs/{job_id}/oidctoken"
79+
# Construct the OIDC token request URL.
80+
# Format: {collection_uri}{project_id}/_apis/distributedtask/hubs/{hubName}/plans/{planId}/jobs/{jobId}/oidctoken.
81+
request_url = f"{self.collection_uri}{self.project_id}/_apis/distributedtask/hubs/{self.hub_name}/plans/{self.plan_id}/jobs/{self.job_id}/oidctoken"
6382

64-
# Add API version (audience is fixed to "api://AzureADTokenExchange" by Azure DevOps)
83+
# Add API version (audience is fixed to "api://AzureADTokenExchange" by Azure DevOps).
6584
endpoint = f"{request_url}?api-version=7.2-preview.1"
6685
headers = {
67-
"Authorization": f"Bearer {access_token}",
86+
"Authorization": f"Bearer {self.access_token}",
6887
"Content-Type": "application/json",
6988
"Content-Length": "0",
7089
}
7190

72-
# Azure DevOps OIDC endpoint requires POST request with empty body
91+
# Azure DevOps OIDC endpoint requires POST request with empty body.
7392
response = requests.post(endpoint, headers=headers)
7493
if not response.ok:
7594
logger.debug(f"Azure DevOps OIDC: token request failed with status {response.status_code}")
7695
return None
7796

78-
# Azure DevOps returns the token in 'oidcToken' field
97+
# Azure DevOps returns the token in 'oidcToken' field.
7998
response_json = response.json()
8099
if "oidcToken" not in response_json:
81100
logger.debug("Azure DevOps OIDC: response missing 'oidcToken' field")

0 commit comments

Comments
 (0)