Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
## Release v0.60.0

### New Features and Improvements

* Added headers to HttpRequestResponse in OpenAI client.

### Bug Fixes

- Correctly issue in OIDC implementation that prevented the use of the feature (see #994).
- Fix a reported issue where `FilesExt` fails to retry if it receives certain status code from server.

### Documentation
Expand Down
4 changes: 2 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def file_oidc(cfg) -> Optional[CredentialsProvider]:
# that provides a Databricks token from an IdTokenSource.
def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]:
try:
id_token = id_token_source.id_token()
id_token_source.id_token() # validate the id_token_source
except Exception as e:
logger.debug(f"Failed to get OIDC token: {e}")
return None
Expand All @@ -341,7 +341,7 @@ def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Opti
token_endpoint=cfg.oidc_endpoints.token_endpoint,
client_id=cfg.client_id,
account_id=cfg.account_id,
id_token=id_token,
id_token_source=id_token_source,
disable_async=cfg.disable_async_token_refresh,
)

Expand Down
8 changes: 6 additions & 2 deletions databricks/sdk/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,18 @@ def token(self) -> oauth.Token:
logger.debug("Client ID provided, authenticating with Workload Identity Federation")

id_token = self._id_token_source.id_token()
return self._exchange_id_token(id_token)

# This function is used to create the OAuth client.
# It exists to make it easier to test.
def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
client = oauth.ClientCredentials(
client_id=self._client_id,
client_secret="", # we have no (rotatable) secrets in OIDC flow
client_secret="",
token_url=self._token_endpoint,
endpoint_params={
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
"subject_token": id_token,
"subject_token": id_token.jwt,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes=["all-apis"],
Expand Down
75 changes: 67 additions & 8 deletions tests/test_credentials_provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import datetime, timedelta
from unittest.mock import Mock

from databricks.sdk.credentials_provider import external_browser
from databricks.sdk import credentials_provider, oauth, oidc


# Tests for external_browser function
def test_external_browser_refresh_success(mocker):
"""Tests successful refresh of existing credentials."""

Expand All @@ -21,7 +23,9 @@ def test_external_browser_refresh_success(mocker):
mock_token_cache.load.return_value = mock_session_credentials

# Mock SessionCredentials.
want_credentials_provider = lambda c: "new_credentials"
def want_credentials_provider(_):
return "new_credentials"

mock_session_credentials.return_value = want_credentials_provider

# Inject the mock implementations.
Expand All @@ -30,7 +34,7 @@ def test_external_browser_refresh_success(mocker):
return_value=mock_token_cache,
)

got_credentials_provider = external_browser(mock_cfg)
got_credentials_provider = credentials_provider.external_browser(mock_cfg)

mock_token_cache.load.assert_called_once()
mock_session_credentials.token.assert_called_once() # Verify token refresh was attempted
Expand All @@ -55,7 +59,9 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker):
mock_token_cache.load.return_value = mock_session_credentials

# Mock SessionCredentials.
want_credentials_provider = lambda c: "new_credentials"
def want_credentials_provider(_):
return "new_credentials"

mock_session_credentials.return_value = want_credentials_provider

# Mock OAuthClient.
Expand All @@ -74,7 +80,7 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker):
return_value=mock_oauth_client,
)

got_credentials_provider = external_browser(mock_cfg)
got_credentials_provider = credentials_provider.external_browser(mock_cfg)

mock_token_cache.load.assert_called_once()
mock_session_credentials.token.assert_called_once() # Refresh attempt
Expand All @@ -101,7 +107,10 @@ def test_external_browser_no_cached_credentials(mocker):

# Mock SessionCredentials.
mock_session_credentials = Mock()
want_credentials_provider = lambda c: "new_credentials"

def want_credentials_provider(_):
return "new_credentials"

mock_session_credentials.return_value = want_credentials_provider

# Mock OAuthClient.
Expand All @@ -120,7 +129,7 @@ def test_external_browser_no_cached_credentials(mocker):
return_value=mock_oauth_client,
)

got_credentials_provider = external_browser(mock_cfg)
got_credentials_provider = credentials_provider.external_browser(mock_cfg)

mock_token_cache.load.assert_called_once()
mock_oauth_client.initiate_consent.assert_called_once()
Expand Down Expand Up @@ -158,8 +167,58 @@ def test_external_browser_consent_fails(mocker):
return_value=mock_oauth_client,
)

got_credentials_provider = external_browser(mock_cfg)
got_credentials_provider = credentials_provider.external_browser(mock_cfg)

mock_token_cache.load.assert_called_once()
mock_oauth_client.initiate_consent.assert_called_once()
assert got_credentials_provider is None


def test_oidc_credentials_provider_invalid_id_token_source():
# Use a mock config object to avoid initializing the auth initialization.
mock_cfg = Mock()
mock_cfg.host = "https://test-workspace.cloud.databricks.com"
mock_cfg.oidc_endpoints = Mock()
mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token"
mock_cfg.client_id = "test-client-id"
mock_cfg.account_id = "test-account-id"
mock_cfg.disable_async_token_refresh = True

# An IdTokenSource that raises an error when id_token() is called.
id_token_source = Mock()
id_token_source.id_token.side_effect = ValueError("Invalid ID token source")

cp = credentials_provider._oidc_credentials_provider(mock_cfg, id_token_source)
assert cp is None


def test_oidc_credentials_provider_valid_id_token_source(mocker):
# Use a mock config object to avoid initializing the auth initialization.
mock_cfg = Mock()
mock_cfg.host = "https://test-workspace.cloud.databricks.com"
mock_cfg.oidc_endpoints = Mock()
mock_cfg.oidc_endpoints.token_endpoint = "https://test-workspace.cloud.databricks.com/oidc/v1/token"
mock_cfg.client_id = "test-client-id"
mock_cfg.account_id = "test-account-id"
mock_cfg.disable_async_token_refresh = True

# A valid IdTokenSource that never raises an error.
id_token_source = Mock()
id_token_source.id_token.return_value = oidc.IdToken(jwt="test-jwt-token")

# Mock the _exchange_id_token method on DatabricksOidcTokenSource to return
# a valid oauth.Token based on the IdToken.
def mock_exchange_id_token(id_token: oidc.IdToken):
# Create a token based on the input ID token
return oauth.Token(
access_token=f"exchanged-{id_token.jwt}", token_type="Bearer", expiry=datetime.now() + timedelta(hours=1)
)

mocker.patch.object(oidc.DatabricksOidcTokenSource, "_exchange_id_token", side_effect=mock_exchange_id_token)

cp = credentials_provider._oidc_credentials_provider(mock_cfg, id_token_source)
assert cp is not None

# Test that the credentials provider returns the expected headers
headers = cp()
assert headers == {"Authorization": "Bearer exchanged-test-jwt-token"}
Loading