diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index efc02e856..e01c87b28 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,10 +3,12 @@ ## Release v0.60.0 ### New Features and Improvements + * Added headers to HttpRequestResponse in OpenAI client. ### Bug Fixes +- Correctly issue in OIDC implementation that prevented the use of the feature (see #994). - Fix a reported issue where `FilesExt` fails to retry if it receives certain status code from server. ### Documentation diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 2f5121180..8f8d54624 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -331,7 +331,7 @@ def file_oidc(cfg) -> Optional[CredentialsProvider]: # that provides a Databricks token from an IdTokenSource. def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]: try: - id_token = id_token_source.id_token() + id_token_source.id_token() # validate the id_token_source except Exception as e: logger.debug(f"Failed to get OIDC token: {e}") return None @@ -341,7 +341,7 @@ def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Opti token_endpoint=cfg.oidc_endpoints.token_endpoint, client_id=cfg.client_id, account_id=cfg.account_id, - id_token=id_token, + id_token_source=id_token_source, disable_async=cfg.disable_async_token_refresh, ) diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index 5c0af2949..9f39e3d72 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -188,14 +188,18 @@ def token(self) -> oauth.Token: logger.debug("Client ID provided, authenticating with Workload Identity Federation") id_token = self._id_token_source.id_token() + return self._exchange_id_token(id_token) + # This function is used to create the OAuth client. + # It exists to make it easier to test. + def _exchange_id_token(self, id_token: IdToken) -> oauth.Token: client = oauth.ClientCredentials( client_id=self._client_id, - client_secret="", # we have no (rotatable) secrets in OIDC flow + client_secret="", token_url=self._token_endpoint, endpoint_params={ "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "subject_token": id_token, + "subject_token": id_token.jwt, "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", }, scopes=["all-apis"], diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index b23044d7c..13f16531c 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -1,8 +1,10 @@ +from datetime import datetime, timedelta from unittest.mock import Mock -from databricks.sdk.credentials_provider import external_browser +from databricks.sdk import credentials_provider, oauth, oidc +# Tests for external_browser function def test_external_browser_refresh_success(mocker): """Tests successful refresh of existing credentials.""" @@ -21,7 +23,9 @@ def test_external_browser_refresh_success(mocker): mock_token_cache.load.return_value = mock_session_credentials # Mock SessionCredentials. - want_credentials_provider = lambda c: "new_credentials" + def want_credentials_provider(_): + return "new_credentials" + mock_session_credentials.return_value = want_credentials_provider # Inject the mock implementations. @@ -30,7 +34,7 @@ def test_external_browser_refresh_success(mocker): return_value=mock_token_cache, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_session_credentials.token.assert_called_once() # Verify token refresh was attempted @@ -55,7 +59,9 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker): mock_token_cache.load.return_value = mock_session_credentials # Mock SessionCredentials. - want_credentials_provider = lambda c: "new_credentials" + def want_credentials_provider(_): + return "new_credentials" + mock_session_credentials.return_value = want_credentials_provider # Mock OAuthClient. @@ -74,7 +80,7 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker): return_value=mock_oauth_client, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_session_credentials.token.assert_called_once() # Refresh attempt @@ -101,7 +107,10 @@ def test_external_browser_no_cached_credentials(mocker): # Mock SessionCredentials. mock_session_credentials = Mock() - want_credentials_provider = lambda c: "new_credentials" + + def want_credentials_provider(_): + return "new_credentials" + mock_session_credentials.return_value = want_credentials_provider # Mock OAuthClient. @@ -120,7 +129,7 @@ def test_external_browser_no_cached_credentials(mocker): return_value=mock_oauth_client, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_oauth_client.initiate_consent.assert_called_once() @@ -158,8 +167,58 @@ def test_external_browser_consent_fails(mocker): return_value=mock_oauth_client, ) - got_credentials_provider = external_browser(mock_cfg) + got_credentials_provider = credentials_provider.external_browser(mock_cfg) mock_token_cache.load.assert_called_once() mock_oauth_client.initiate_consent.assert_called_once() assert got_credentials_provider is None + + +def test_oidc_credentials_provider_invalid_id_token_source(): + # Use a mock config object to avoid initializing the auth initialization. + mock_cfg = Mock() + mock_cfg.host = "https://test-workspace.cloud.databricks.com" + mock_cfg.oidc_endpoints = Mock() + mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token" + mock_cfg.client_id = "test-client-id" + mock_cfg.account_id = "test-account-id" + mock_cfg.disable_async_token_refresh = True + + # An IdTokenSource that raises an error when id_token() is called. + id_token_source = Mock() + id_token_source.id_token.side_effect = ValueError("Invalid ID token source") + + cp = credentials_provider._oidc_credentials_provider(mock_cfg, id_token_source) + assert cp is None + + +def test_oidc_credentials_provider_valid_id_token_source(mocker): + # Use a mock config object to avoid initializing the auth initialization. + mock_cfg = Mock() + mock_cfg.host = "https://test-workspace.cloud.databricks.com" + mock_cfg.oidc_endpoints = Mock() + mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token" + mock_cfg.client_id = "test-client-id" + mock_cfg.account_id = "test-account-id" + mock_cfg.disable_async_token_refresh = True + + # A valid IdTokenSource that never raises an error. + id_token_source = Mock() + id_token_source.id_token.return_value = oidc.IdToken(jwt="test-jwt-token") + + # Mock the _exchange_id_token method on DatabricksOidcTokenSource to return + # a valid oauth.Token based on the IdToken. + def mock_exchange_id_token(id_token: oidc.IdToken): + # Create a token based on the input ID token + return oauth.Token( + access_token=f"exchanged-{id_token.jwt}", token_type="Bearer", expiry=datetime.now() + timedelta(hours=1) + ) + + mocker.patch.object(oidc.DatabricksOidcTokenSource, "_exchange_id_token", side_effect=mock_exchange_id_token) + + cp = credentials_provider._oidc_credentials_provider(mock_cfg, id_token_source) + assert cp is not None + + # Test that the credentials provider returns the expected headers + headers = cp() + assert headers == {"Authorization": "Bearer exchanged-test-jwt-token"}