Skip to content

Commit dff96ac

Browse files
Save refreshed token in the cache
1 parent 01c3c98 commit dff96ac

File tree

2 files changed

+56
-55
lines changed

2 files changed

+56
-55
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,18 @@ def token() -> Token:
188188
def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
189189
if cfg.auth_type != 'external-browser':
190190
return None
191+
191192
client_id, client_secret = None, None
192193
if cfg.client_id:
193194
client_id = cfg.client_id
194195
client_secret = cfg.client_secret
195196
elif cfg.azure_client_id:
196197
client_id = cfg.azure_client
197198
client_secret = cfg.azure_client_secret
198-
199199
if not client_id:
200200
client_id = 'databricks-cli'
201201

202-
# Load cached credentials from disk if they exist. Note that these are
202+
# Load cached credentials from disk if they exist. Note that these are
203203
# local to the Python SDK and not reused by other SDKs.
204204
oidc_endpoints = cfg.oidc_endpoints
205205
redirect_url = 'http://localhost:8020'
@@ -211,12 +211,13 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
211211
credentials = token_cache.load()
212212
if credentials:
213213
try:
214-
# Force a refresh in case the loaded credentials are expired.
215-
# If the refresh fails, rather than throw exception we will
216-
# initiate a new OAuth login flow.
217-
credentials.token() # force a token refresh
214+
# Pro-actively refresh the loaded credentials. This is done
215+
# to detect if the token is expired and needs to be refreshed
216+
# by going through the OAuth login flow.
217+
credentials.token()
218+
token_cache.save(credentials)
218219
return credentials(cfg)
219-
# TODO: we should ideally use more specific exceptions.
220+
# TODO: We should ideally use more specific exceptions.
220221
except Exception as e:
221222
logger.warning(f'Failed to refresh cached token: {e}. Initiating new OAuth login flow')
222223

@@ -227,6 +228,7 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
227228
consent = oauth_client.initiate_consent()
228229
if not consent:
229230
return None
231+
230232
credentials = consent.launch_external_browser()
231233
token_cache.save(credentials)
232234
return credentials(cfg)

tests/test_credentials_provider.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,147 +1,146 @@
1-
import pytest
2-
31
from unittest.mock import Mock
2+
43
from databricks.sdk.credentials_provider import external_browser
54

5+
66
def test_external_browser_refresh_success(mocker):
77
"""Tests successful refresh of existing credentials."""
88

9-
# 1. Mock Config
9+
# Mock Config.
1010
mock_cfg = Mock()
1111
mock_cfg.auth_type = 'external-browser'
1212
mock_cfg.host = 'test-host'
1313
mock_cfg.oidc_endpoints = {'token_endpoint': 'test-token-endpoint'}
14-
mock_cfg.client_id = 'test-client-id' # Or use azure_client_id
15-
mock_cfg.client_secret = 'test-client-secret' # Or use azure_client_secret
14+
mock_cfg.client_id = 'test-client-id' # Or use azure_client_id
15+
mock_cfg.client_secret = 'test-client-secret' # Or use azure_client_secret
1616

17-
# 2. Mock TokenCache
17+
# Mock TokenCache.
1818
mock_token_cache = Mock()
1919
mock_session_credentials = Mock()
20-
mock_session_credentials.token.return_value = "valid_token" # Simulate successful refresh
20+
mock_session_credentials.token.return_value = "valid_token" # Simulate successful refresh
2121
mock_token_cache.load.return_value = mock_session_credentials
2222

23-
mock_credentials_provider = Mock()
24-
mock_session_credentials.return_value = mock_credentials_provider
23+
# Mock SessionCredentials.
24+
want_credentials_provider = lambda c: "new_credentials"
25+
mock_session_credentials.return_value = want_credentials_provider
2526

26-
# 3. Patch TokenCache (no need to mock OAuthClient in this case)
27+
# Inject the mock implementations.
2728
mocker.patch('databricks.sdk.credentials_provider.TokenCache', return_value=mock_token_cache)
2829

29-
# 4. Call the function
30-
result = external_browser(mock_cfg)
30+
got_credentials_provider = external_browser(mock_cfg)
3131

32-
# 5. Assertions
3332
mock_token_cache.load.assert_called_once()
34-
mock_session_credentials.token.assert_called_once() # Verify token refresh was attempted
35-
assert result == mock_credentials_provider
33+
mock_token_cache.save.assert_called_once_with(mock_session_credentials)
34+
mock_session_credentials.token.assert_called_once() # Verify token refresh was attempted
35+
assert got_credentials_provider == want_credentials_provider
3636

3737

3838
def test_external_browser_refresh_failure_new_oauth_flow(mocker):
3939
"""Tests failed refresh, triggering a new OAuth flow."""
4040

41-
# 1. Mock Config
41+
# Mock Config.
4242
mock_cfg = Mock()
4343
mock_cfg.auth_type = 'external-browser'
4444
mock_cfg.host = 'test-host'
4545
mock_cfg.oidc_endpoints = {'token_endpoint': 'test-token-endpoint'}
4646
mock_cfg.client_id = 'test-client-id'
4747
mock_cfg.client_secret = 'test-client-secret'
4848

49-
# 2. Mock TokenCache
49+
# Mock TokenCache.
5050
mock_token_cache = Mock()
5151
mock_session_credentials = Mock()
52-
mock_session_credentials.token.side_effect = Exception("Simulated refresh error") # Simulate a failed refresh
52+
mock_session_credentials.token.side_effect = Exception(
53+
"Simulated refresh error") # Simulate a failed refresh
5354
mock_token_cache.load.return_value = mock_session_credentials
5455

55-
mock_credentials_provider = Mock()
56-
mock_session_credentials.return_value = mock_credentials_provider
56+
# Mock SessionCredentials.
57+
want_credentials_provider = lambda c: "new_credentials"
58+
mock_session_credentials.return_value = want_credentials_provider
5759

58-
# 3. Mock OAuthClient
60+
# Mock OAuthClient.
5961
mock_oauth_client = Mock()
6062
mock_consent = Mock()
61-
mock_consent.launch_external_browser.return_value = mock_session_credentials # Simulate successful OAuth flow
63+
mock_consent.launch_external_browser.return_value = mock_session_credentials
6264
mock_oauth_client.initiate_consent.return_value = mock_consent
6365

64-
# 4. Patch TokenCache and OAuthClient
66+
# Inject the mock implementations.
6567
mocker.patch('databricks.sdk.credentials_provider.TokenCache', return_value=mock_token_cache)
6668
mocker.patch('databricks.sdk.credentials_provider.OAuthClient', return_value=mock_oauth_client)
6769

68-
# 5. Call the function
69-
result = external_browser(mock_cfg)
70+
got_credentials_provider = external_browser(mock_cfg)
7071

71-
# 6. Assertions
7272
mock_token_cache.load.assert_called_once()
73-
mock_session_credentials.token.assert_called_once() # Refresh attempt
73+
mock_session_credentials.token.assert_called_once() # Refresh attempt
7474
mock_oauth_client.initiate_consent.assert_called_once()
7575
mock_consent.launch_external_browser.assert_called_once()
7676
mock_token_cache.save.assert_called_once_with(mock_session_credentials)
77-
assert result == mock_credentials_provider
77+
assert got_credentials_provider == want_credentials_provider
7878

7979

8080
def test_external_browser_no_cached_credentials(mocker):
8181
"""Tests the case where there are no cached credentials, initiating a new OAuth flow."""
8282

83-
# 1. Mock Config
83+
# Mock Config.
8484
mock_cfg = Mock()
8585
mock_cfg.auth_type = 'external-browser'
8686
mock_cfg.host = 'test-host'
8787
mock_cfg.oidc_endpoints = {'token_endpoint': 'test-token-endpoint'}
8888
mock_cfg.client_id = 'test-client-id'
8989
mock_cfg.client_secret = 'test-client-secret'
9090

91-
# 2. Mock TokenCache
91+
# Mock TokenCache.
9292
mock_token_cache = Mock()
93-
mock_token_cache.load.return_value = None # No cached credentials
93+
mock_token_cache.load.return_value = None # No cached credentials
9494

95-
mock_session_credentials = lambda c: "new_credentials"
95+
# Mock SessionCredentials.
96+
mock_session_credentials = Mock()
97+
want_credentials_provider = lambda c: "new_credentials"
98+
mock_session_credentials.return_value = want_credentials_provider
9699

97-
# 3. Mock OAuthClient
100+
# Mock OAuthClient.
98101
mock_consent = Mock()
99102
mock_consent.launch_external_browser.return_value = mock_session_credentials
100103
mock_oauth_client = Mock()
101104
mock_oauth_client.initiate_consent.return_value = mock_consent
102105

103-
# 4. Patch TokenCache and OAuthClient
106+
# Inject the mock implementations.
104107
mocker.patch('databricks.sdk.credentials_provider.TokenCache', return_value=mock_token_cache)
105108
mocker.patch('databricks.sdk.credentials_provider.OAuthClient', return_value=mock_oauth_client)
106109

107-
# 5. Call the function
108-
result = external_browser(mock_cfg)
110+
got_credentials_provider = external_browser(mock_cfg)
109111

110-
# 6. Assertions
111112
mock_token_cache.load.assert_called_once()
112113
mock_oauth_client.initiate_consent.assert_called_once()
113114
mock_consent.launch_external_browser.assert_called_once()
114115
mock_token_cache.save.assert_called_once_with(mock_session_credentials)
115-
assert result == "new_credentials"
116+
assert got_credentials_provider == want_credentials_provider
116117

117118

118119
def test_external_browser_consent_fails(mocker):
119120
"""Tests the case where OAuth consent initiation fails."""
120121

121-
# 1. Mock Config
122+
# Mock Config.
122123
mock_cfg = Mock()
123124
mock_cfg.auth_type = 'external-browser'
124125
mock_cfg.host = 'test-host'
125126
mock_cfg.oidc_endpoints = {'token_endpoint': 'test-token-endpoint'}
126127
mock_cfg.client_id = 'test-client-id'
127128
mock_cfg.client_secret = 'test-client-secret'
128129

129-
# 2. Mock TokenCache
130+
# Mock TokenCache.
130131
mock_token_cache = Mock()
131-
mock_token_cache.load.return_value = None # No cached credentials
132+
mock_token_cache.load.return_value = None # No cached credentials
132133

133-
# 3. Mock OAuthClient
134+
# Mock OAuthClient.
134135
mock_oauth_client = Mock()
135-
mock_oauth_client.initiate_consent.return_value = None # Simulate consent failure
136+
mock_oauth_client.initiate_consent.return_value = None # Simulate consent failure
136137

137-
# 4. Patch TokenCache and OAuthClient
138+
# Inject the mock implementations.
138139
mocker.patch('databricks.sdk.credentials_provider.TokenCache', return_value=mock_token_cache)
139140
mocker.patch('databricks.sdk.credentials_provider.OAuthClient', return_value=mock_oauth_client)
140141

141-
# 5. Call the function
142-
result = external_browser(mock_cfg)
142+
got_credentials_provider = external_browser(mock_cfg)
143143

144-
# 6. Assertions
145144
mock_token_cache.load.assert_called_once()
146145
mock_oauth_client.initiate_consent.assert_called_once()
147-
assert result is None
146+
assert got_credentials_provider is None

0 commit comments

Comments
 (0)