1+ from datetime import datetime , timedelta
12from unittest .mock import Mock
23
3- from databricks .sdk .credentials_provider import external_browser
4+ from databricks .sdk import oauth , oidc
5+ from databricks .sdk .credentials_provider import (_oidc_credentials_provider ,
6+ external_browser )
47
58
9+ # Tests for external_browser function
610def test_external_browser_refresh_success (mocker ):
711 """Tests successful refresh of existing credentials."""
812
@@ -21,7 +25,9 @@ def test_external_browser_refresh_success(mocker):
2125 mock_token_cache .load .return_value = mock_session_credentials
2226
2327 # Mock SessionCredentials.
24- want_credentials_provider = lambda c : "new_credentials"
28+ def want_credentials_provider (_ ):
29+ return "new_credentials"
30+
2531 mock_session_credentials .return_value = want_credentials_provider
2632
2733 # Inject the mock implementations.
@@ -55,7 +61,9 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker):
5561 mock_token_cache .load .return_value = mock_session_credentials
5662
5763 # Mock SessionCredentials.
58- want_credentials_provider = lambda c : "new_credentials"
64+ def want_credentials_provider (_ ):
65+ return "new_credentials"
66+
5967 mock_session_credentials .return_value = want_credentials_provider
6068
6169 # Mock OAuthClient.
@@ -101,7 +109,10 @@ def test_external_browser_no_cached_credentials(mocker):
101109
102110 # Mock SessionCredentials.
103111 mock_session_credentials = Mock ()
104- want_credentials_provider = lambda c : "new_credentials"
112+
113+ def want_credentials_provider (_ ):
114+ return "new_credentials"
115+
105116 mock_session_credentials .return_value = want_credentials_provider
106117
107118 # Mock OAuthClient.
@@ -163,3 +174,53 @@ def test_external_browser_consent_fails(mocker):
163174 mock_token_cache .load .assert_called_once ()
164175 mock_oauth_client .initiate_consent .assert_called_once ()
165176 assert got_credentials_provider is None
177+
178+
179+ def test_oidc_credentials_provider_invalid_id_token_source ():
180+ # Use a mock config object to avoid initializing the auth initialization.
181+ mock_cfg = Mock ()
182+ mock_cfg .host = "https://test-workspace.cloud.databricks.com"
183+ mock_cfg .oidc_endpoints = Mock ()
184+ mock_cfg .oidc_endpoints .token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token"
185+ mock_cfg .client_id = "test-client-id"
186+ mock_cfg .account_id = "test-account-id"
187+ mock_cfg .disable_async_token_refresh = True
188+
189+ # An IdTokenSource that raises an error when id_token() is called.
190+ id_token_source = Mock ()
191+ id_token_source .id_token .side_effect = ValueError ("Invalid ID token source" )
192+
193+ credentials_provider = _oidc_credentials_provider (mock_cfg , id_token_source )
194+ assert credentials_provider is None
195+
196+
197+ def test_oidc_credentials_provider_valid_id_token_source (mocker ):
198+ # Use a mock config object to avoid initializing the auth initialization.
199+ mock_cfg = Mock ()
200+ mock_cfg .host = "https://test-workspace.cloud.databricks.com"
201+ mock_cfg .oidc_endpoints = Mock ()
202+ mock_cfg .oidc_endpoints .token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token"
203+ mock_cfg .client_id = "test-client-id"
204+ mock_cfg .account_id = "test-account-id"
205+ mock_cfg .disable_async_token_refresh = True
206+
207+ # A valid IdTokenSource that never raises an error.
208+ id_token_source = Mock ()
209+ id_token_source .id_token .return_value = oidc .IdToken (jwt = "test-jwt-token" )
210+
211+ # Mock the _exchange_id_token method on DatabricksOidcTokenSource to return
212+ # a valid oauth.Token based on the IdToken.
213+ def mock_exchange_id_token (id_token : oidc .IdToken ):
214+ # Create a token based on the input ID token
215+ return oauth .Token (
216+ access_token = f"exchanged-{ id_token .jwt } " , token_type = "Bearer" , expiry = datetime .now () + timedelta (hours = 1 )
217+ )
218+
219+ mocker .patch .object (oidc .DatabricksOidcTokenSource , "_exchange_id_token" , side_effect = mock_exchange_id_token )
220+
221+ credentials_provider = _oidc_credentials_provider (mock_cfg , id_token_source )
222+ assert credentials_provider is not None
223+
224+ # Test that the credentials provider returns the expected headers
225+ headers = credentials_provider ()
226+ assert headers == {"Authorization" : "Bearer exchanged-test-jwt-token" }
0 commit comments