Skip to content

Commit cd7c236

Browse files
Fix OIDC issue + test
1 parent 0ec0dcb commit cd7c236

File tree

3 files changed

+73
-8
lines changed

3 files changed

+73
-8
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def file_oidc(cfg) -> Optional[CredentialsProvider]:
331331
# that provides a Databricks token from an IdTokenSource.
332332
def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]:
333333
try:
334-
id_token = id_token_source.id_token()
334+
id_token_source.id_token() # validate the id_token_source is valid
335335
except Exception as e:
336336
logger.debug(f"Failed to get OIDC token: {e}")
337337
return None
@@ -341,7 +341,7 @@ def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Opti
341341
token_endpoint=cfg.oidc_endpoints.token_endpoint,
342342
client_id=cfg.client_id,
343343
account_id=cfg.account_id,
344-
id_token=id_token,
344+
id_token_source=id_token_source,
345345
disable_async=cfg.disable_async_token_refresh,
346346
)
347347

databricks/sdk/oidc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,18 @@ def token(self) -> oauth.Token:
188188
logger.debug("Client ID provided, authenticating with Workload Identity Federation")
189189

190190
id_token = self._id_token_source.id_token()
191+
return self._exchange_id_token(id_token)
191192

193+
# This function is used to create the OAuth client.
194+
# It exists to make it easier to test.
195+
def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
192196
client = oauth.ClientCredentials(
193197
client_id=self._client_id,
194-
client_secret="", # we have no (rotatable) secrets in OIDC flow
198+
client_secret="",
195199
token_url=self._token_endpoint,
196200
endpoint_params={
197201
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
198-
"subject_token": id_token,
202+
"subject_token": id_token.jwt,
199203
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
200204
},
201205
scopes=["all-apis"],

tests/test_credentials_provider.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from datetime import datetime, timedelta
12
from unittest.mock import Mock
23

3-
from databricks.sdk.credentials_provider import external_browser
4+
from databricks.sdk import oauth, oidc
5+
from databricks.sdk.credentials_provider import (_oidc_credentials_provider,
6+
external_browser)
47

58

9+
# Tests for external_browser function
610
def test_external_browser_refresh_success(mocker):
711
"""Tests successful refresh of existing credentials."""
812

@@ -21,7 +25,9 @@ def test_external_browser_refresh_success(mocker):
2125
mock_token_cache.load.return_value = mock_session_credentials
2226

2327
# Mock SessionCredentials.
24-
want_credentials_provider = lambda c: "new_credentials"
28+
def want_credentials_provider(_):
29+
return "new_credentials"
30+
2531
mock_session_credentials.return_value = want_credentials_provider
2632

2733
# Inject the mock implementations.
@@ -55,7 +61,9 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker):
5561
mock_token_cache.load.return_value = mock_session_credentials
5662

5763
# Mock SessionCredentials.
58-
want_credentials_provider = lambda c: "new_credentials"
64+
def want_credentials_provider(_):
65+
return "new_credentials"
66+
5967
mock_session_credentials.return_value = want_credentials_provider
6068

6169
# Mock OAuthClient.
@@ -101,7 +109,10 @@ def test_external_browser_no_cached_credentials(mocker):
101109

102110
# Mock SessionCredentials.
103111
mock_session_credentials = Mock()
104-
want_credentials_provider = lambda c: "new_credentials"
112+
113+
def want_credentials_provider(_):
114+
return "new_credentials"
115+
105116
mock_session_credentials.return_value = want_credentials_provider
106117

107118
# Mock OAuthClient.
@@ -163,3 +174,53 @@ def test_external_browser_consent_fails(mocker):
163174
mock_token_cache.load.assert_called_once()
164175
mock_oauth_client.initiate_consent.assert_called_once()
165176
assert got_credentials_provider is None
177+
178+
179+
def test_oidc_credentials_provider_invalid_id_token_source():
180+
# Use a mock config object to avoid initializing the auth initialization.
181+
mock_cfg = Mock()
182+
mock_cfg.host = "https://test-workspace.cloud.databricks.com"
183+
mock_cfg.oidc_endpoints = Mock()
184+
mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token"
185+
mock_cfg.client_id = "test-client-id"
186+
mock_cfg.account_id = "test-account-id"
187+
mock_cfg.disable_async_token_refresh = True
188+
189+
# An IdTokenSource that raises an error when id_token() is called.
190+
id_token_source = Mock()
191+
id_token_source.id_token.side_effect = ValueError("Invalid ID token source")
192+
193+
credentials_provider = _oidc_credentials_provider(mock_cfg, id_token_source)
194+
assert credentials_provider is None
195+
196+
197+
def test_oidc_credentials_provider_valid_id_token_source(mocker):
198+
# Use a mock config object to avoid initializing the auth initialization.
199+
mock_cfg = Mock()
200+
mock_cfg.host = "https://test-workspace.cloud.databricks.com"
201+
mock_cfg.oidc_endpoints = Mock()
202+
mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token"
203+
mock_cfg.client_id = "test-client-id"
204+
mock_cfg.account_id = "test-account-id"
205+
mock_cfg.disable_async_token_refresh = True
206+
207+
# A valid IdTokenSource that never raises an error.
208+
id_token_source = Mock()
209+
id_token_source.id_token.return_value = oidc.IdToken(jwt="test-jwt-token")
210+
211+
# Mock the _exchange_id_token method on DatabricksOidcTokenSource to return
212+
# a valid oauth.Token based on the IdToken.
213+
def mock_exchange_id_token(id_token: oidc.IdToken):
214+
# Create a token based on the input ID token
215+
return oauth.Token(
216+
access_token=f"exchanged-{id_token.jwt}", token_type="Bearer", expiry=datetime.now() + timedelta(hours=1)
217+
)
218+
219+
mocker.patch.object(oidc.DatabricksOidcTokenSource, "_exchange_id_token", side_effect=mock_exchange_id_token)
220+
221+
credentials_provider = _oidc_credentials_provider(mock_cfg, id_token_source)
222+
assert credentials_provider is not None
223+
224+
# Test that the credentials provider returns the expected headers
225+
headers = credentials_provider()
226+
assert headers == {"Authorization": "Bearer exchanged-test-jwt-token"}

0 commit comments

Comments
 (0)