Skip to content

Commit b0750eb

Browse files
authored
[Fix] Infer Azure tenant ID if not set (#638)
## Changes

Port of databricks/databricks-sdk-go#910 to the Python SDK.

In order to use Azure U2M or M2M authentication with the Databricks SDK, users must request a token from the Entra ID instance that the underlying workspace or account belongs to, as Databricks rejects requests to workspaces with a token from a different Entra ID tenant. However, with Azure CLI auth, it is possible that a user is logged into multiple tenants at the same time.

Currently, the SDK uses the subscription ID from the configured Azure Resource ID for the workspace when issuing the `az account get-access-token` command. However, when users don't specify the resource ID, the SDK simply fetches a token for the active subscription for the user. If the active subscription is in a different tenant than the workspace, users will see an error such as:

```
io.jsonwebtoken.IncorrectClaimException: Expected iss claim to be: https://sts.windows.net/72f988bf-86f1-41af-91ab-2d7cd011db47/, but was: https://sts.windows.net/e3fe3f22-4b98-4c04-82cc-d8817d1b17da/
```

This PR modifies Azure CLI and Azure SP credential providers to attempt to load the tenant ID of the workspace if not provided before authenticating.

Currently, there are no unauthenticated endpoints that the tenant ID can be directly fetched from. However, the tenant ID is indirectly exposed via the redirect URL used when logging into a workspace. In this PR, we fetch the tenant ID from this endpoint and configure it if not already set.

Here, we lazily fetch the tenant ID only in the auth methods that need it. This prevents us from making any unnecessary requests if these Azure credential providers are not needed.

## Tests

Unit tests check that the tenant ID is fetched automatically if not specified for an Azure workspace when authenticating with client ID/secret or with the CLI.

- [x] `make test` run locally
- [x] `make fmt` applied
- [x] relevant integration tests applied
1 parent f5c5f48 commit b0750eb

File tree

6 files changed

+112
-24
lines changed

6 files changed

+112
-24
lines changed

databricks/sdk/config.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,33 @@ def _fix_host_if_needed(self):
363363

364364
self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
365365

366+
def load_azure_tenant_id(self):
    """[Internal] Load the Azure tenant ID from the Azure Databricks login page.

    Sends an unauthenticated GET to ``<host>/aad/auth`` without following the
    redirect and takes the first path segment of the ``Location`` header as
    the tenant ID.

    The lookup is best-effort: a transport error, a non-3xx response, a
    missing ``Location`` header, or a header whose path carries no tenant
    segment is logged at debug level and leaves ``azure_tenant_id`` untouched.

    If the tenant ID is already set, this method does nothing. The same holds
    when the config is not for Azure or no host is configured."""
    if not self.is_azure or self.azure_tenant_id is not None or self.host is None:
        return
    login_url = f'{self.host}/aad/auth'
    logger.debug('Loading tenant ID from %s', login_url)
    try:
        # requests never times out unless told to; bound the probe so an
        # unreachable host cannot hang authentication indefinitely.
        resp = requests.get(login_url, allow_redirects=False, timeout=30)
    except requests.exceptions.RequestException as err:
        # Keep the lookup best-effort, like every other failure path below.
        logger.debug('Failed to get tenant ID from %s: %s', login_url, err)
        return
    if resp.status_code // 100 != 3:
        logger.debug('Failed to get tenant ID from %s: expected status code 3xx, got %s', login_url,
                     resp.status_code)
        return
    entra_id_endpoint = resp.headers.get('Location')
    if entra_id_endpoint is None:
        logger.debug('No Location header in response from %s', login_url)
        return
    # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
    # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
    url = urllib.parse.urlparse(entra_id_endpoint)
    path_segments = url.path.split('/')
    # An absolute path splits into ['', '<tenant-id>', ...]; a bare '/' splits
    # into ['', ''], so an empty first segment must be rejected as well.
    if len(path_segments) < 2 or not path_segments[1]:
        logger.debug('Invalid path in Location header: %s', url.path)
        return
    self.azure_tenant_id = path_segments[1]
    logger.debug('Loaded tenant ID: %s', self.azure_tenant_id)
392+
366393
def _set_inner_config(self, keyword_args: Dict[str, any]):
367394
for attr in self.attributes():
368395
if attr.name not in keyword_args:

databricks/sdk/credentials_provider.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,7 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
233233
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
234234

235235

236-
@oauth_credentials_strategy('azure-client-secret',
237-
['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
236+
@oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
238237
def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
239238
""" Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
240239
to every request, while automatically resolving different Azure environment endpoints. """
@@ -248,6 +247,7 @@ def token_source_for(resource: str) -> TokenSource:
248247
use_params=True)
249248

250249
_ensure_host_present(cfg, token_source_for)
250+
cfg.load_azure_tenant_id()
251251
logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
252252
inner = token_source_for(cfg.effective_azure_login_app_id)
253253
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
@@ -432,11 +432,13 @@ def refresh(self) -> Token:
432432
class AzureCliTokenSource(CliTokenSource):
433433
""" Obtain the token granted by `az login` CLI command """
434434

435-
def __init__(self, resource: str, subscription: str = ""):
435+
def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
436436
cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
437-
if subscription != "":
437+
if subscription is not None:
438438
cmd.append("--subscription")
439439
cmd.append(subscription)
440+
if tenant:
441+
cmd.extend(["--tenant", tenant])
440442
super().__init__(cmd=cmd,
441443
token_type_field='tokenType',
442444
access_token_field='accessToken',
@@ -464,8 +466,10 @@ def is_human_user(self) -> bool:
464466
@staticmethod
465467
def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
466468
subscription = AzureCliTokenSource.get_subscription(cfg)
467-
if subscription != "":
468-
token_source = AzureCliTokenSource(resource, subscription)
469+
if subscription is not None:
470+
token_source = AzureCliTokenSource(resource,
471+
subscription=subscription,
472+
tenant=cfg.azure_tenant_id)
469473
try:
470474
# This will fail if the user has access to the workspace, but not to the subscription
471475
# itself.
@@ -475,25 +479,26 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
475479
except OSError:
476480
logger.warning("Failed to get token for subscription. Using resource only token.")
477481

478-
token_source = AzureCliTokenSource(resource)
482+
token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id)
479483
token_source.token()
480484
return token_source
481485

482486
@staticmethod
483-
def get_subscription(cfg: 'Config') -> str:
487+
def get_subscription(cfg: 'Config') -> Optional[str]:
484488
resource = cfg.azure_workspace_resource_id
485489
if resource is None or resource == "":
486-
return ""
490+
return None
487491
components = resource.split('/')
488492
if len(components) < 3:
489493
logger.warning("Invalid azure workspace resource ID")
490-
return ""
494+
return None
491495
return components[2]
492496

493497

494498
@credentials_strategy('azure-cli', ['is_azure'])
495499
def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
496500
""" Adds refreshed OAuth token granted by `az login` command to every request. """
501+
cfg.load_azure_tenant_id()
497502
token_source = None
498503
mgmt_token_source = None
499504
try:
@@ -517,11 +522,6 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
517522

518523
_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
519524
logger.info("Using Azure CLI authentication with AAD tokens")
520-
if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "":
521-
logger.warning(
522-
"azure_workspace_resource_id field not provided. "
523-
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
524-
)
525525

526526
def inner() -> Dict[str, str]:
527527
token = token_source.token()

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,16 @@ def set_az_path(monkeypatch):
7777
monkeypatch.setenv('COMSPEC', 'C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe')
7878
else:
7979
monkeypatch.setenv('PATH', __tests__ + "/testdata:/bin")
80+
81+
82+
@pytest.fixture
def mock_tenant(requests_mock):
    """Provide a helper that stubs the ``/aad/auth`` endpoint of a given host.

    The helper registers a GET mock answering with a 302 whose ``Location``
    header embeds ``tenant_id``, and returns the registered mock."""

    def stub_tenant_request(host, tenant_id="test-tenant-id"):
        redirect_target = f'https://login.microsoftonline.com/{tenant_id}/oauth2/authorize'
        return requests_mock.get(f'https://{host}/aad/auth',
                                 status_code=302,
                                 headers={'Location': redirect_target})

    return stub_tenant_request

tests/test_auth.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,10 @@ def test_config_azure_pat():
193193
assert cfg.is_azure
194194

195195

196-
def test_config_azure_cli_host(monkeypatch):
196+
def test_config_azure_cli_host(monkeypatch, mock_tenant):
197197
set_home(monkeypatch, '/testdata/azure')
198198
set_az_path(monkeypatch)
199+
mock_tenant('adb-123.4.azuredatabricks.net')
199200
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')
200201

201202
assert cfg.auth_type == 'azure-cli'
@@ -229,20 +230,22 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def
229230
cfg = Config(token='x', azure_workspace_resource_id='/sub/rg/ws')
230231

231232

232-
def test_config_azure_cli_host_and_resource_id(monkeypatch):
233+
def test_config_azure_cli_host_and_resource_id(monkeypatch, mock_tenant):
233234
set_home(monkeypatch, '/testdata')
234235
set_az_path(monkeypatch)
236+
mock_tenant('adb-123.4.azuredatabricks.net')
235237
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')
236238

237239
assert cfg.auth_type == 'azure-cli'
238240
assert cfg.host == 'https://adb-123.4.azuredatabricks.net'
239241
assert cfg.is_azure
240242

241243

242-
def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch):
244+
def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch, mock_tenant):
243245
monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost')
244246
set_home(monkeypatch, '/testdata/azure')
245247
set_az_path(monkeypatch)
248+
mock_tenant('adb-123.4.azuredatabricks.net')
246249
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')
247250

248251
assert cfg.auth_type == 'azure-cli'

tests/test_auth_manual_tests.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from .conftest import set_az_path, set_home
44

55

6-
def test_azure_cli_workspace_header_present(monkeypatch):
6+
def test_azure_cli_workspace_header_present(monkeypatch, mock_tenant):
77
set_home(monkeypatch, '/testdata/azure')
88
set_az_path(monkeypatch)
9+
mock_tenant('adb-123.4.azuredatabricks.net')
910
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
1011
cfg = Config(auth_type='azure-cli',
1112
host='https://adb-123.4.azuredatabricks.net',
@@ -14,19 +15,21 @@ def test_azure_cli_workspace_header_present(monkeypatch):
1415
assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id
1516

1617

17-
def test_azure_cli_user_with_management_access(monkeypatch):
18+
def test_azure_cli_user_with_management_access(monkeypatch, mock_tenant):
1819
set_home(monkeypatch, '/testdata/azure')
1920
set_az_path(monkeypatch)
21+
mock_tenant('adb-123.4.azuredatabricks.net')
2022
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
2123
cfg = Config(auth_type='azure-cli',
2224
host='https://adb-123.4.azuredatabricks.net',
2325
azure_workspace_resource_id=resource_id)
2426
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()
2527

2628

27-
def test_azure_cli_user_no_management_access(monkeypatch):
29+
def test_azure_cli_user_no_management_access(monkeypatch, mock_tenant):
2830
set_home(monkeypatch, '/testdata/azure')
2931
set_az_path(monkeypatch)
32+
mock_tenant('adb-123.4.azuredatabricks.net')
3033
monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/')
3134
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
3235
cfg = Config(auth_type='azure-cli',
@@ -35,9 +38,10 @@ def test_azure_cli_user_no_management_access(monkeypatch):
3538
assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate()
3639

3740

38-
def test_azure_cli_fallback(monkeypatch):
41+
def test_azure_cli_fallback(monkeypatch, mock_tenant):
3942
set_home(monkeypatch, '/testdata/azure')
4043
set_az_path(monkeypatch)
44+
mock_tenant('adb-123.4.azuredatabricks.net')
4145
monkeypatch.setenv('FAIL_IF', 'subscription')
4246
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
4347
cfg = Config(auth_type='azure-cli',
@@ -46,9 +50,10 @@ def test_azure_cli_fallback(monkeypatch):
4650
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()
4751

4852

49-
def test_azure_cli_with_warning_on_stderr(monkeypatch):
53+
def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant):
5054
set_home(monkeypatch, '/testdata/azure')
5155
set_az_path(monkeypatch)
56+
mock_tenant('adb-123.4.azuredatabricks.net')
5257
monkeypatch.setenv('WARN', 'this is a warning')
5358
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
5459
cfg = Config(auth_type='azure-cli',

tests/test_config.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import platform
23

34
import pytest
@@ -6,7 +7,9 @@
67
from databricks.sdk.config import Config, with_product, with_user_agent_extra
78
from databricks.sdk.version import __version__
89

9-
from .conftest import noop_credentials
10+
from .conftest import noop_credentials, set_az_path
11+
12+
__tests__ = os.path.dirname(__file__)
1013

1114

1215
def test_config_supports_legacy_credentials_provider():
@@ -74,3 +77,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
7477
assert "blueprint/0.4.6" in config.user_agent
7578
assert "blueprint/0.4.6" in config_copy.user_agent
7679
useragent._reset_extra(original_extra)
80+
81+
82+
def test_load_azure_tenant_id_404(requests_mock, monkeypatch):
    """A 404 answer from /aad/auth leaves the tenant ID unset."""
    set_az_path(monkeypatch)
    login_endpoint = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404)

    config = Config(host="https://abc123.azuredatabricks.net")

    assert config.azure_tenant_id is None
    assert login_endpoint.called_once
88+
89+
90+
def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch):
    """A 302 answer without a Location header leaves the tenant ID unset."""
    set_az_path(monkeypatch)
    login_endpoint = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302)

    config = Config(host="https://abc123.azuredatabricks.net")

    assert config.azure_tenant_id is None
    assert login_endpoint.called_once
96+
97+
98+
def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch):
    """A Location header with no path to read a tenant from leaves the tenant ID unset."""
    set_az_path(monkeypatch)
    redirect_headers = {'Location': 'https://unexpected-location'}
    login_endpoint = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth',
                                       status_code=302,
                                       headers=redirect_headers)

    config = Config(host="https://abc123.azuredatabricks.net")

    assert config.azure_tenant_id is None
    assert login_endpoint.called_once
106+
107+
108+
def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch):
    """A 302 answer whose Location path starts with the tenant populates the tenant ID."""
    set_az_path(monkeypatch)
    redirect_headers = {'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}
    login_endpoint = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth',
                                       status_code=302,
                                       headers=redirect_headers)

    config = Config(host="https://abc123.azuredatabricks.net")

    assert config.azure_tenant_id == 'tenant-id'
    assert login_endpoint.called_once

0 commit comments

Comments
 (0)