Skip to content

Commit d2f29fe

Browse files
Fix config.oauth_token() to properly rely on token caching. (#1020)
## What changes are proposed in this pull request? This PR fixes `config.oauth_token()` to actually use the caching mechanism in the credentials provider (e.g. async refresh). It also improve the error message in case the credentials provider is not an OAuth credentials provider. ## How is this tested? Unit + Integration tests.
1 parent 991bd63 commit d2f29fe

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

NEXT_CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
### Bug Fixes
88

9+
* Fix `Config.oauth_token()` to avoid re-creating a new `CredentialsProvider` at each call. This fix indirectly makes `oauth_token()` benefit from the internal caching mechanism of some providers.
10+
911
### Documentation
1012

1113
### Internal Changes

databricks/sdk/config.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from . import useragent
1414
from ._base_client import _fix_host_if_needed
1515
from .clock import Clock, RealClock
16-
from .credentials_provider import CredentialsStrategy, DefaultCredentials
16+
from .credentials_provider import (CredentialsStrategy, DefaultCredentials,
17+
OAuthCredentialsProvider)
1718
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
1819
DatabricksEnvironment, get_environment_for_hostname)
1920
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
@@ -200,7 +201,19 @@ def __init__(
200201
raise ValueError(message) from e
201202

202203
def oauth_token(self) -> Token:
203-
return self._credentials_strategy.oauth_token(self)
204+
"""Returns the OAuth token from the current credential provider.
205+
206+
This method only works when using OAuth-based authentication methods.
207+
If the current credential provider is an OAuthCredentialsProvider, it reuses
208+
the existing provider. Otherwise, it raises a ValueError indicating that
209+
OAuth tokens are not available for the current authentication method.
210+
"""
211+
if isinstance(self._header_factory, OAuthCredentialsProvider):
212+
return self._header_factory.oauth_token()
213+
raise ValueError(
214+
f"OAuth tokens are not available for {self.auth_type} authentication. "
215+
f"Use an OAuth-based authentication method to access OAuth tokens."
216+
)
204217

205218
def wrap_debug_info(self, message: str) -> str:
206219
debug_string = self.debug_string()

tests/test_config.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,74 @@ def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch):
189189
cfg = Config(host="https://abc123.azuredatabricks.net")
190190
assert cfg.azure_tenant_id == "tenant-id"
191191
assert mock.called_once
192+
193+
194+
def test_oauth_token_with_pat_auth():
195+
"""Test that oauth_token() raises an error for PAT authentication."""
196+
config = Config(host="https://test.databricks.com", token="dapi1234567890abcdef")
197+
198+
with pytest.raises(ValueError) as exc_info:
199+
config.oauth_token()
200+
201+
assert "OAuth tokens are not available for pat authentication" in str(exc_info.value)
202+
203+
204+
def test_oauth_token_with_basic_auth():
205+
"""Test that oauth_token() raises an error for basic authentication."""
206+
config = Config(host="https://test.databricks.com", username="testuser", password="testpass")
207+
208+
with pytest.raises(ValueError) as exc_info:
209+
config.oauth_token()
210+
211+
assert "OAuth tokens are not available for basic authentication" in str(exc_info.value)
212+
213+
214+
def test_oauth_token_with_oauth_provider(mocker):
215+
"""Test that oauth_token() works correctly for OAuth authentication."""
216+
from databricks.sdk.credentials_provider import OAuthCredentialsProvider
217+
from databricks.sdk.oauth import Token
218+
219+
# Create a mock OAuth token
220+
mock_token = Token(access_token="mock_access_token", token_type="Bearer", refresh_token="mock_refresh_token")
221+
222+
# Create a mock OAuth provider
223+
mock_oauth_provider = mocker.Mock(spec=OAuthCredentialsProvider)
224+
mock_oauth_provider.oauth_token.return_value = mock_token
225+
226+
# Create config with noop credentials to avoid network calls
227+
config = Config(host="https://test.databricks.com", credentials_strategy=noop_credentials)
228+
229+
# Replace the header factory with our mock
230+
config._header_factory = mock_oauth_provider
231+
232+
# Test that oauth_token() works and returns the expected token
233+
token = config.oauth_token()
234+
assert token == mock_token
235+
mock_oauth_provider.oauth_token.assert_called_once()
236+
237+
238+
def test_oauth_token_reuses_existing_provider(mocker):
239+
"""Test that oauth_token() reuses the existing OAuthCredentialsProvider."""
240+
from databricks.sdk.credentials_provider import OAuthCredentialsProvider
241+
from databricks.sdk.oauth import Token
242+
243+
# Create a mock OAuth token
244+
mock_token = Token(access_token="mock_access_token", token_type="Bearer", refresh_token="mock_refresh_token")
245+
246+
# Create a mock OAuth provider
247+
mock_oauth_provider = mocker.Mock(spec=OAuthCredentialsProvider)
248+
mock_oauth_provider.oauth_token.return_value = mock_token
249+
250+
# Create config with noop credentials to avoid network calls
251+
config = Config(host="https://test.databricks.com", credentials_strategy=noop_credentials)
252+
253+
# Replace the header factory with our mock
254+
config._header_factory = mock_oauth_provider
255+
256+
# Call oauth_token() multiple times to verify reuse
257+
token1 = config.oauth_token()
258+
token2 = config.oauth_token()
259+
260+
# Both calls should work and use the same provider instance
261+
assert token1 == token2 == mock_token
262+
assert mock_oauth_provider.oauth_token.call_count == 2

0 commit comments

Comments
 (0)