Skip to content

Commit 4faa7ca

Browse files
Test + fmt
1 parent 8be1f12 commit 4faa7ca

File tree

2 files changed

+78
-4
lines changed

2 files changed

+78
-4
lines changed

databricks/sdk/config.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from . import useragent
1414
from ._base_client import _fix_host_if_needed
1515
from .clock import Clock, RealClock
16-
from .credentials_provider import CredentialsStrategy, DefaultCredentials, OAuthCredentialsProvider
16+
from .credentials_provider import (CredentialsStrategy, DefaultCredentials,
17+
OAuthCredentialsProvider)
1718
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
1819
DatabricksEnvironment, get_environment_for_hostname)
1920
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
@@ -201,7 +202,7 @@ def __init__(
201202

202203
def oauth_token(self) -> Token:
203204
"""Returns the OAuth token from the current credential provider.
204-
205+
205206
This method only works when using OAuth-based authentication methods.
206207
If the current credential provider is an OAuthCredentialsProvider, it reuses
207208
the existing provider. Otherwise, it raises a ValueError indicating that
@@ -211,8 +212,10 @@ def oauth_token(self) -> Token:
211212
if isinstance(self._header_factory, OAuthCredentialsProvider):
212213
return self._header_factory.oauth_token()
213214
# Raise an error for non-OAuth authentication methods
214-
raise ValueError(f"OAuth tokens are not available for {self.auth_type} authentication. "
215-
f"Use an OAuth-based authentication method to access OAuth tokens.")
215+
raise ValueError(
216+
f"OAuth tokens are not available for {self.auth_type} authentication. "
217+
f"Use an OAuth-based authentication method to access OAuth tokens."
218+
)
216219

217220
def wrap_debug_info(self, message: str) -> str:
218221
debug_string = self.debug_string()

tests/test_config.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,74 @@ def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch):
189189
cfg = Config(host="https://abc123.azuredatabricks.net")
190190
assert cfg.azure_tenant_id == "tenant-id"
191191
assert mock.called_once
192+
193+
194+
def test_oauth_token_with_pat_auth():
195+
"""Test that oauth_token() raises an error for PAT authentication."""
196+
config = Config(host="https://test.databricks.com", token="dapi1234567890abcdef")
197+
198+
with pytest.raises(ValueError) as exc_info:
199+
config.oauth_token()
200+
201+
assert "OAuth tokens are not available for pat authentication" in str(exc_info.value)
202+
203+
204+
def test_oauth_token_with_basic_auth():
205+
"""Test that oauth_token() raises an error for basic authentication."""
206+
config = Config(host="https://test.databricks.com", username="testuser", password="testpass")
207+
208+
with pytest.raises(ValueError) as exc_info:
209+
config.oauth_token()
210+
211+
assert "OAuth tokens are not available for basic authentication" in str(exc_info.value)
212+
213+
214+
def test_oauth_token_with_oauth_provider(mocker):
215+
"""Test that oauth_token() works correctly for OAuth authentication."""
216+
from databricks.sdk.credentials_provider import OAuthCredentialsProvider
217+
from databricks.sdk.oauth import Token
218+
219+
# Create a mock OAuth token
220+
mock_token = Token(access_token="mock_access_token", token_type="Bearer", refresh_token="mock_refresh_token")
221+
222+
# Create a mock OAuth provider
223+
mock_oauth_provider = mocker.Mock(spec=OAuthCredentialsProvider)
224+
mock_oauth_provider.oauth_token.return_value = mock_token
225+
226+
# Create config with mocked header factory
227+
config = Config(host="https://test.databricks.com", client_id="test-client-id", client_secret="test-client-secret")
228+
229+
# Replace the header factory with our mock
230+
config._header_factory = mock_oauth_provider
231+
232+
# Test that oauth_token() works and returns the expected token
233+
token = config.oauth_token()
234+
assert token == mock_token
235+
mock_oauth_provider.oauth_token.assert_called_once()
236+
237+
238+
def test_oauth_token_reuses_existing_provider(mocker):
239+
"""Test that oauth_token() reuses the existing OAuthCredentialsProvider."""
240+
from databricks.sdk.credentials_provider import OAuthCredentialsProvider
241+
from databricks.sdk.oauth import Token
242+
243+
# Create a mock OAuth token
244+
mock_token = Token(access_token="mock_access_token", token_type="Bearer", refresh_token="mock_refresh_token")
245+
246+
# Create a mock OAuth provider
247+
mock_oauth_provider = mocker.Mock(spec=OAuthCredentialsProvider)
248+
mock_oauth_provider.oauth_token.return_value = mock_token
249+
250+
# Create config with mocked header factory
251+
config = Config(host="https://test.databricks.com", client_id="test-client-id", client_secret="test-client-secret")
252+
253+
# Replace the header factory with our mock
254+
config._header_factory = mock_oauth_provider
255+
256+
# Call oauth_token() multiple times to verify reuse
257+
token1 = config.oauth_token()
258+
token2 = config.oauth_token()
259+
260+
# Both calls should work and use the same provider instance
261+
assert token1 == token2 == mock_token
262+
assert mock_oauth_provider.oauth_token.call_count == 2

0 commit comments

Comments
 (0)