1515 get_python_sql_connector_auth_provider ,
1616 PYSQL_OAUTH_CLIENT_ID ,
1717)
18- from databricks .sql .auth .oauth import OAuthManager
18+ from databricks .sql .auth .oauth import OAuthManager , Token , ClientCredentialsTokenSource
1919from databricks .sql .auth .authenticators import (
2020 DatabricksOAuthProvider ,
2121 AzureServicePrincipalCredentialProvider ,
22- Token ,
2322)
2423from databricks .sql .auth .endpoint import (
2524 CloudType ,
@@ -198,16 +197,16 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
198197 self .assertTrue (auth_provider ._client_id , PYSQL_OAUTH_CLIENT_ID )
199198
200199
201- class TestAzureServicePrincipalCredentialProvider :
200+ class TestClientCredentialsTokenSource :
202201 @pytest .fixture
203202 def indefinite_token (self ):
204203 secret_key = "mysecret"
205204 expires_in_100_years = int (time .time ()) + (100 * 365 * 24 * 60 * 60 )
206205
207206 payload = {"sub" : "user123" , "role" : "admin" , "exp" : expires_in_100_years }
208207
209- token = jwt .encode (payload , secret_key , algorithm = "HS256" )
210- return Token (token , "Bearer" , "refresh_token" )
208+ access_token = jwt .encode (payload , secret_key , algorithm = "HS256" )
209+ return Token (access_token , "Bearer" , "refresh_token" )
211210
212211 @pytest .fixture
213212 def http_response (self ):
@@ -224,67 +223,75 @@ def status_response(response_status_code):
224223 return status_response
225224
226225 @pytest .fixture
227- def provider (self ):
228- return AzureServicePrincipalCredentialProvider (
229- client_id = "dummy-client " ,
230- client_secret = "dummy-secret " ,
231- tenant_id = "dummy-tenant " ,
226+ def token_source (self ):
227+ return ClientCredentialsTokenSource (
228+ token_url = "https://token_url.com " ,
229+ oauth_client_id = "client_id " ,
230+ oauth_client_secret = "client_secret " ,
232231 )
233232
234- def test_token_refresh (self , provider ):
235- with patch .object (provider , "_get_token" ) as mock_get_token :
236- mock_get_token .return_value = Token (
237- "access_token" , "Bearer" , "refresh_token"
238- )
239- header_factory = provider ()
240- headers = header_factory ()
241-
242- assert headers ["Authorization" ] == "Bearer access_token"
243- mock_get_token .assert_called_once ()
244-
245233 def test_no_token_refresh__when_token_is_not_expired (
246- self , provider , indefinite_token
234+ self , token_source , indefinite_token
247235 ):
248- with patch .object (provider , "_get_token " ) as mock_get_token :
236+ with patch .object (token_source , "refresh " ) as mock_get_token :
249237 mock_get_token .return_value = indefinite_token
250238
251- # Call the provider multiple times
252- header_factory1 = provider ()
253- header_factory2 = provider ()
254- header_factory3 = provider ()
255-
256- # Get headers from each factory
257- headers1 = header_factory1 ()
258- headers2 = header_factory2 ()
259- headers3 = header_factory3 ()
239+ # Mulitple calls for token
240+ token1 = token_source .get_token ()
241+ token2 = token_source .get_token ()
242+ token3 = token_source .get_token ()
260243
261- # Verify _get_token was called only once
262- mock_get_token .assert_called_once ()
244+ assert token1 == token2 == token3
245+ assert token1 .access_token == indefinite_token .access_token
246+ assert token1 .token_type == indefinite_token .token_type
247+ assert token1 .refresh_token == indefinite_token .refresh_token
263248
264- # Verify all headers contain the same token
265- expected_auth_header = f"Bearer { indefinite_token .access_token } "
266- assert headers1 ["Authorization" ] == expected_auth_header
267- assert headers2 ["Authorization" ] == expected_auth_header
268- assert headers3 ["Authorization" ] == expected_auth_header
249+ # should refresh only once as token is not expired
250+ assert mock_get_token .call_count == 1
269251
270- def test_get_token_success (self , provider , http_response ):
271-
272- # Patch the HTTP client's execute method
273- with patch .object (
274- provider ._http_client , "execute" , return_value = http_response (200 )
275- ) as mock_execute :
276- token = provider ._get_token ()
252+ def test_get_token_success (self , token_source , http_response ):
253+ with patch .object (token_source ._http_client , "execute" ) as mock_execute :
254+ mock_execute .return_value = http_response (200 )
255+ token = token_source .get_token ()
277256
278257 # Assert
279258 assert isinstance (token , Token )
280259 assert token .access_token == "abc123"
281260 assert token .token_type == "Bearer"
282261 assert token .refresh_token is None
283262
284- def test_get_token_failure (self , provider , http_response ):
285- with patch .object (
286- provider ._http_client , "execute" , return_value = http_response (400 )
287- ) as mock_execute :
263+ def test_get_token_failure (self , token_source , http_response ):
264+ with patch .object (token_source ._http_client , "execute" ) as mock_execute :
265+ mock_execute .return_value = http_response (400 )
288266 with pytest .raises (Exception ) as e :
289- provider . _get_token ()
267+ token_source . get_token ()
290268 assert "Failed to get token: 400" in str (e .value )
269+
270+
271+ class TestAzureServicePrincipalCredentialProvider :
272+ @pytest .fixture
273+ def credential_provider (self ):
274+ return AzureServicePrincipalCredentialProvider (
275+ hostname = "hostname" ,
276+ oauth_client_id = "client_id" ,
277+ oauth_client_secret = "client_secret" ,
278+ azure_tenant_id = "tenant_id" ,
279+ )
280+
281+ def test_provider_credentials (self , credential_provider ):
282+
283+ test_token = Token ("access_token" , "Bearer" , "refresh_token" )
284+
285+ with patch .object (
286+ credential_provider , "get_token_source"
287+ ) as mock_get_token_source :
288+ mock_get_token_source .return_value = MagicMock ()
289+ mock_get_token_source .return_value .get_token .return_value = test_token
290+
291+ headers = credential_provider ()()
292+
293+ assert headers ["Authorization" ] == f"Bearer { test_token .access_token } "
294+ assert (
295+ headers ["X-Databricks-Azure-SP-Management-Token" ]
296+ == test_token .access_token
297+ )
0 commit comments