From cd7c23622a4b7ea0291b912d74a6a039d45a5e91 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 23 Jul 2025 09:14:06 +0000 Subject: [PATCH 1/4] Fix OIDC issue + test --- databricks/sdk/credentials_provider.py | 4 +- databricks/sdk/oidc.py | 8 ++- tests/test_credentials_provider.py | 69 ++++++++++++++++++++++++-- 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 2f5121180..409828161 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -331,7 +331,7 @@ def file_oidc(cfg) -> Optional[CredentialsProvider]: # that provides a Databricks token from an IdTokenSource. def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]: try: - id_token = id_token_source.id_token() + id_token_source.id_token() # validate the id_token_source is valid except Exception as e: logger.debug(f"Failed to get OIDC token: {e}") return None @@ -341,7 +341,7 @@ def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Opti token_endpoint=cfg.oidc_endpoints.token_endpoint, client_id=cfg.client_id, account_id=cfg.account_id, - id_token=id_token, + id_token_source=id_token_source, disable_async=cfg.disable_async_token_refresh, ) diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index 5c0af2949..9f39e3d72 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -188,14 +188,18 @@ def token(self) -> oauth.Token: logger.debug("Client ID provided, authenticating with Workload Identity Federation") id_token = self._id_token_source.id_token() + return self._exchange_id_token(id_token) + # This function is used to create the OAuth client. + # It exists to make it easier to test. + def _exchange_id_token(self, id_token: IdToken) -> oauth.Token: client = oauth.ClientCredentials( client_id=self._client_id, - client_secret="", # we have no (rotatable) secrets in OIDC flow + client_secret="", token_url=self._token_endpoint, endpoint_params={ "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "subject_token": id_token, + "subject_token": id_token.jwt, "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", }, scopes=["all-apis"], diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index b23044d7c..96e5fd37b 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -1,8 +1,12 @@ +from datetime import datetime, timedelta from unittest.mock import Mock -from databricks.sdk.credentials_provider import external_browser +from databricks.sdk import oauth, oidc +from databricks.sdk.credentials_provider import (_oidc_credentials_provider, + external_browser) +# Tests for external_browser function def test_external_browser_refresh_success(mocker): """Tests successful refresh of existing credentials.""" @@ -21,7 +25,9 @@ def test_external_browser_refresh_success(mocker): mock_token_cache.load.return_value = mock_session_credentials # Mock SessionCredentials. - want_credentials_provider = lambda c: "new_credentials" + def want_credentials_provider(_): + return "new_credentials" + mock_session_credentials.return_value = want_credentials_provider # Inject the mock implementations. @@ -55,7 +61,9 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker): mock_token_cache.load.return_value = mock_session_credentials # Mock SessionCredentials. - want_credentials_provider = lambda c: "new_credentials" + def want_credentials_provider(_): + return "new_credentials" + mock_session_credentials.return_value = want_credentials_provider # Mock OAuthClient. @@ -101,7 +109,10 @@ def test_external_browser_no_cached_credentials(mocker): # Mock SessionCredentials. mock_session_credentials = Mock() - want_credentials_provider = lambda c: "new_credentials" + + def want_credentials_provider(_): + return "new_credentials" + mock_session_credentials.return_value = want_credentials_provider # Mock OAuthClient. @@ -163,3 +174,53 @@ def test_external_browser_consent_fails(mocker): mock_token_cache.load.assert_called_once() mock_oauth_client.initiate_consent.assert_called_once() assert got_credentials_provider is None + + +def test_oidc_credentials_provider_invalid_id_token_source(): + # Use a mock config object to avoid initializing the auth initialization. + mock_cfg = Mock() + mock_cfg.host = "https://test-workspace.cloud.databricks.com" + mock_cfg.oidc_endpoints = Mock() + mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token" + mock_cfg.client_id = "test-client-id" + mock_cfg.account_id = "test-account-id" + mock_cfg.disable_async_token_refresh = True + + # An IdTokenSource that raises an error when id_token() is called. + id_token_source = Mock() + id_token_source.id_token.side_effect = ValueError("Invalid ID token source") + + credentials_provider = _oidc_credentials_provider(mock_cfg, id_token_source) + assert credentials_provider is None + + +def test_oidc_credentials_provider_valid_id_token_source(mocker): + # Use a mock config object to avoid initializing the auth initialization. + mock_cfg = Mock() + mock_cfg.host = "https://test-workspace.cloud.databricks.com" + mock_cfg.oidc_endpoints = Mock() + mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token" + mock_cfg.client_id = "test-client-id" + mock_cfg.account_id = "test-account-id" + mock_cfg.disable_async_token_refresh = True + + # A valid IdTokenSource that never raises an error. + id_token_source = Mock() + id_token_source.id_token.return_value = oidc.IdToken(jwt="test-jwt-token") + + # Mock the _exchange_id_token method on DatabricksOidcTokenSource to return + # a valid oauth.Token based on the IdToken. + def mock_exchange_id_token(id_token: oidc.IdToken): + # Create a token based on the input ID token + return oauth.Token( + access_token=f"exchanged-{id_token.jwt}", token_type="Bearer", expiry=datetime.now() + timedelta(hours=1) + ) + + mocker.patch.object(oidc.DatabricksOidcTokenSource, "_exchange_id_token", side_effect=mock_exchange_id_token) + + credentials_provider = _oidc_credentials_provider(mock_cfg, id_token_source) + assert credentials_provider is not None + + # Test that the credentials provider returns the expected headers + headers = credentials_provider() + assert headers == {"Authorization": "Bearer exchanged-test-jwt-token"} From 3f2e81ed57900d8046996b551cc4cade0f4376cf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 23 Jul 2025 09:18:07 +0000 Subject: [PATCH 2/4] Add changelog --- NEXT_CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index efc02e856..e01c87b28 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,10 +3,12 @@ ## Release v0.60.0 ### New Features and Improvements + * Added headers to HttpRequestResponse in OpenAI client. ### Bug Fixes +- Correctly issue in OIDC implementation that prevented the use of the feature (see #994). - Fix a reported issue where `FilesExt` fails to retry if it receives certain status code from server. ### Documentation From 0d01a0b6e6143146a93b42501e0f97bafa091b64 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 23 Jul 2025 09:20:27 +0000 Subject: [PATCH 3/4] Minor formatting change --- tests/test_credentials_provider.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index 96e5fd37b..13f16531c 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -1,9 +1,7 @@ from datetime import datetime, timedelta from unittest.mock import Mock -from databricks.sdk import oauth, oidc -from databricks.sdk.credentials_provider import (_oidc_credentials_provider, - external_browser) +from databricks.sdk import credentials_provider, oauth, oidc # Tests for external_browser function @@ -36,7 +34,7 @@ def want_credentials_provider(_): return_value=mock_token_cache, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_session_credentials.token.assert_called_once() # Verify token refresh was attempted @@ -82,7 +80,7 @@ def want_credentials_provider(_): return_value=mock_oauth_client, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_session_credentials.token.assert_called_once() # Refresh attempt @@ -131,7 +129,7 @@ def want_credentials_provider(_): return_value=mock_oauth_client, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_oauth_client.initiate_consent.assert_called_once() @@ -169,7 +167,7 @@ def test_external_browser_consent_fails(mocker): return_value=mock_oauth_client, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_oauth_client.initiate_consent.assert_called_once() @@ -190,8 +188,8 @@ def test_oidc_credentials_provider_invalid_id_token_source(): id_token_source = Mock() id_token_source.id_token.side_effect = ValueError("Invalid ID token source") - credentials_provider = _oidc_credentials_provider(mock_cfg, id_token_source) - assert credentials_provider is None + cp = credentials_provider._oidc_credentials_provider(mock_cfg, id_token_source) + assert cp is None def test_oidc_credentials_provider_valid_id_token_source(mocker): @@ -218,9 +216,9 @@ def mock_exchange_id_token(id_token: oidc.IdToken): mocker.patch.object(oidc.DatabricksOidcTokenSource, "_exchange_id_token", side_effect=mock_exchange_id_token) - credentials_provider = _oidc_credentials_provider(mock_cfg, id_token_source) - assert credentials_provider is not None + cp = credentials_provider._oidc_credentials_provider(mock_cfg, id_token_source) + assert cp is not None # Test that the credentials provider returns the expected headers - headers = credentials_provider() + headers = cp() assert headers == {"Authorization": "Bearer exchanged-test-jwt-token"} From 23740ad7a0b1d85a77bbbf1cbaa97f770fe2c019 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 23 Jul 2025 09:21:12 +0000 Subject: [PATCH 4/4] Minor formatting change --- databricks/sdk/credentials_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 409828161..8f8d54624 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -331,7 +331,7 @@ def file_oidc(cfg) -> Optional[CredentialsProvider]: # that provides a Databricks token from an IdTokenSource. def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]: try: - id_token_source.id_token() # validate the id_token_source is valid + id_token_source.id_token() # validate the id_token_source except Exception as e: logger.debug(f"Failed to get OIDC token: {e}") return None