diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 5fc234f9f..a5721cff8 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,8 @@ ### Bug Fixes +* Fix `Config.oauth_token()` to avoid re-creating a new `CredentialsProvider` at each call. This fix indirectly makes `oauth_token()` benefit from the internal caching mechanism of some providers. + ### Documentation ### Internal Changes diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 487527be7..4cfd8b4f9 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -13,7 +13,8 @@ from . import useragent from ._base_client import _fix_host_if_needed from .clock import Clock, RealClock -from .credentials_provider import CredentialsStrategy, DefaultCredentials +from .credentials_provider import (CredentialsStrategy, DefaultCredentials, + OAuthCredentialsProvider) from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) from .oauth import (OidcEndpoints, Token, get_account_endpoints, @@ -200,7 +201,19 @@ def __init__( raise ValueError(message) from e def oauth_token(self) -> Token: - return self._credentials_strategy.oauth_token(self) + """Returns the OAuth token from the current credential provider. + + This method only works when using OAuth-based authentication methods. + If the current credential provider is an OAuthCredentialsProvider, it reuses + the existing provider. Otherwise, it raises a ValueError indicating that + OAuth tokens are not available for the current authentication method. + """ + if isinstance(self._header_factory, OAuthCredentialsProvider): + return self._header_factory.oauth_token() + raise ValueError( + f"OAuth tokens are not available for {self.auth_type} authentication. " + f"Use an OAuth-based authentication method to access OAuth tokens." + ) def wrap_debug_info(self, message: str) -> str: debug_string = self.debug_string() diff --git a/tests/test_config.py b/tests/test_config.py index b023123af..59fbf8712 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -189,3 +189,74 @@ def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): cfg = Config(host="https://abc123.azuredatabricks.net") assert cfg.azure_tenant_id == "tenant-id" assert mock.called_once + + +def test_oauth_token_with_pat_auth(): + """Test that oauth_token() raises an error for PAT authentication.""" + config = Config(host="https://test.databricks.com", token="dapi1234567890abcdef") + + with pytest.raises(ValueError) as exc_info: + config.oauth_token() + + assert "OAuth tokens are not available for pat authentication" in str(exc_info.value) + + +def test_oauth_token_with_basic_auth(): + """Test that oauth_token() raises an error for basic authentication.""" + config = Config(host="https://test.databricks.com", username="testuser", password="testpass") + + with pytest.raises(ValueError) as exc_info: + config.oauth_token() + + assert "OAuth tokens are not available for basic authentication" in str(exc_info.value) + + +def test_oauth_token_with_oauth_provider(mocker): + """Test that oauth_token() works correctly for OAuth authentication.""" + from databricks.sdk.credentials_provider import OAuthCredentialsProvider + from databricks.sdk.oauth import Token + + # Create a mock OAuth token + mock_token = Token(access_token="mock_access_token", token_type="Bearer", refresh_token="mock_refresh_token") + + # Create a mock OAuth provider + mock_oauth_provider = mocker.Mock(spec=OAuthCredentialsProvider) + mock_oauth_provider.oauth_token.return_value = mock_token + + # Create config with noop credentials to avoid network calls + config = Config(host="https://test.databricks.com", credentials_strategy=noop_credentials) + + # Replace the header factory with our mock + config._header_factory = mock_oauth_provider + + # Test that oauth_token() works and returns the expected token + token = config.oauth_token() + assert token == mock_token + mock_oauth_provider.oauth_token.assert_called_once() + + +def test_oauth_token_reuses_existing_provider(mocker): + """Test that oauth_token() reuses the existing OAuthCredentialsProvider.""" + from databricks.sdk.credentials_provider import OAuthCredentialsProvider + from databricks.sdk.oauth import Token + + # Create a mock OAuth token + mock_token = Token(access_token="mock_access_token", token_type="Bearer", refresh_token="mock_refresh_token") + + # Create a mock OAuth provider + mock_oauth_provider = mocker.Mock(spec=OAuthCredentialsProvider) + mock_oauth_provider.oauth_token.return_value = mock_token + + # Create config with noop credentials to avoid network calls + config = Config(host="https://test.databricks.com", credentials_strategy=noop_credentials) + + # Replace the header factory with our mock + config._header_factory = mock_oauth_provider + + # Call oauth_token() multiple times to verify reuse + token1 = config.oauth_token() + token2 = config.oauth_token() + + # Both calls should work and use the same provider instance + assert token1 == token2 == mock_token + assert mock_oauth_provider.oauth_token.call_count == 2