Skip to content

Commit 0877b6f

Browse files
committed
moved pyjwt to code dependency
1 parent bef7ac6 commit 0877b6f

File tree

7 files changed

+78
-63
lines changed

7 files changed

+78
-63
lines changed

poetry.lock

Lines changed: 5 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@ pyarrow = [
2525
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
2626
{ version = ">=18.0.0", python = ">=3.13", optional=true }
2727
]
28-
pyjwt = { version = "^2.0.0", optional = true }
28+
pyjwt = "^2.0.0"
2929

3030

3131
[tool.poetry.extras]
3232
pyarrow = ["pyarrow"]
33-
jwt = ["pyjwt"]
3433

3534
[tool.poetry.group.dev.dependencies]
3635
pytest = "^7.1.2"

src/databricks/sql/auth/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
DatabricksOAuthProvider,
88
AzureServicePrincipalCredentialProvider,
99
)
10-
from databricks.sql.common.auth import AuthType
10+
from databricks.sql.auth.common import AuthType
1111

1212

1313
class ClientContext:

src/databricks/sql/auth/authenticators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
ClientCredentialsTokenSource,
99
)
1010
from databricks.sql.auth.endpoint import get_oauth_endpoints
11-
from databricks.sql.common.auth import AuthType, get_effective_azure_login_app_id
11+
from databricks.sql.auth.common import AuthType, get_effective_azure_login_app_id
1212

1313
# Private API: this is an evolving interface and it will change in the future.
1414
# Please must not depend on it in your applications.
File renamed without changes.

src/databricks/sql/auth/oauth.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,19 +305,29 @@ def get_tokens(self, hostname: str, scope=None):
305305

306306

307307
class ClientCredentialsTokenSource(RefreshableTokenSource):
308+
"""
309+
A token source that uses client credentials to get a token from the token endpoint.
310+
It will refresh the token if it is expired.
311+
312+
Attributes:
313+
token_url (str): The URL of the token endpoint.
314+
oauth_client_id (str): The client ID.
315+
oauth_client_secret (str): The client secret.
316+
"""
317+
308318
def __init__(
309319
self,
310320
token_url: str,
311321
oauth_client_id: str,
312322
oauth_client_secret: str,
313-
extra_params: dict = None,
323+
extra_params: dict = {},
314324
):
315325
self.oauth_client_id = oauth_client_id
316326
self.oauth_client_secret = oauth_client_secret
317327
self.token_url = token_url
318328
self.extra_params = extra_params
319329
self.token: Token = None
320-
self._http_client = DatabricksHttpClient()
330+
self._http_client = DatabricksHttpClient.get_instance()
321331

322332
def get_token(self) -> Token:
323333
if self.token is None or self.token.is_expired():

tests/unit/test_auth.py

Lines changed: 58 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
get_python_sql_connector_auth_provider,
1616
PYSQL_OAUTH_CLIENT_ID,
1717
)
18-
from databricks.sql.auth.oauth import OAuthManager
18+
from databricks.sql.auth.oauth import OAuthManager, Token, ClientCredentialsTokenSource
1919
from databricks.sql.auth.authenticators import (
2020
DatabricksOAuthProvider,
2121
AzureServicePrincipalCredentialProvider,
22-
Token,
2322
)
2423
from databricks.sql.auth.endpoint import (
2524
CloudType,
@@ -198,16 +197,16 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
198197
self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
199198

200199

201-
class TestAzureServicePrincipalCredentialProvider:
200+
class TestClientCredentialsTokenSource:
202201
@pytest.fixture
203202
def indefinite_token(self):
204203
secret_key = "mysecret"
205204
expires_in_100_years = int(time.time()) + (100 * 365 * 24 * 60 * 60)
206205

207206
payload = {"sub": "user123", "role": "admin", "exp": expires_in_100_years}
208207

209-
token = jwt.encode(payload, secret_key, algorithm="HS256")
210-
return Token(token, "Bearer", "refresh_token")
208+
access_token = jwt.encode(payload, secret_key, algorithm="HS256")
209+
return Token(access_token, "Bearer", "refresh_token")
211210

212211
@pytest.fixture
213212
def http_response(self):
@@ -224,67 +223,75 @@ def status_response(response_status_code):
224223
return status_response
225224

226225
@pytest.fixture
227-
def provider(self):
228-
return AzureServicePrincipalCredentialProvider(
229-
client_id="dummy-client",
230-
client_secret="dummy-secret",
231-
tenant_id="dummy-tenant",
226+
def token_source(self):
227+
return ClientCredentialsTokenSource(
228+
token_url="https://token_url.com",
229+
oauth_client_id="client_id",
230+
oauth_client_secret="client_secret",
232231
)
233232

234-
def test_token_refresh(self, provider):
235-
with patch.object(provider, "_get_token") as mock_get_token:
236-
mock_get_token.return_value = Token(
237-
"access_token", "Bearer", "refresh_token"
238-
)
239-
header_factory = provider()
240-
headers = header_factory()
241-
242-
assert headers["Authorization"] == "Bearer access_token"
243-
mock_get_token.assert_called_once()
244-
245233
def test_no_token_refresh__when_token_is_not_expired(
246-
self, provider, indefinite_token
234+
self, token_source, indefinite_token
247235
):
248-
with patch.object(provider, "_get_token") as mock_get_token:
236+
with patch.object(token_source, "refresh") as mock_get_token:
249237
mock_get_token.return_value = indefinite_token
250238

251-
# Call the provider multiple times
252-
header_factory1 = provider()
253-
header_factory2 = provider()
254-
header_factory3 = provider()
255-
256-
# Get headers from each factory
257-
headers1 = header_factory1()
258-
headers2 = header_factory2()
259-
headers3 = header_factory3()
239+
# Mulitple calls for token
240+
token1 = token_source.get_token()
241+
token2 = token_source.get_token()
242+
token3 = token_source.get_token()
260243

261-
# Verify _get_token was called only once
262-
mock_get_token.assert_called_once()
244+
assert token1 == token2 == token3
245+
assert token1.access_token == indefinite_token.access_token
246+
assert token1.token_type == indefinite_token.token_type
247+
assert token1.refresh_token == indefinite_token.refresh_token
263248

264-
# Verify all headers contain the same token
265-
expected_auth_header = f"Bearer {indefinite_token.access_token}"
266-
assert headers1["Authorization"] == expected_auth_header
267-
assert headers2["Authorization"] == expected_auth_header
268-
assert headers3["Authorization"] == expected_auth_header
249+
# should refresh only once as token is not expired
250+
assert mock_get_token.call_count == 1
269251

270-
def test_get_token_success(self, provider, http_response):
271-
272-
# Patch the HTTP client's execute method
273-
with patch.object(
274-
provider._http_client, "execute", return_value=http_response(200)
275-
) as mock_execute:
276-
token = provider._get_token()
252+
def test_get_token_success(self, token_source, http_response):
253+
with patch.object(token_source._http_client, "execute") as mock_execute:
254+
mock_execute.return_value = http_response(200)
255+
token = token_source.get_token()
277256

278257
# Assert
279258
assert isinstance(token, Token)
280259
assert token.access_token == "abc123"
281260
assert token.token_type == "Bearer"
282261
assert token.refresh_token is None
283262

284-
def test_get_token_failure(self, provider, http_response):
285-
with patch.object(
286-
provider._http_client, "execute", return_value=http_response(400)
287-
) as mock_execute:
263+
def test_get_token_failure(self, token_source, http_response):
264+
with patch.object(token_source._http_client, "execute") as mock_execute:
265+
mock_execute.return_value = http_response(400)
288266
with pytest.raises(Exception) as e:
289-
provider._get_token()
267+
token_source.get_token()
290268
assert "Failed to get token: 400" in str(e.value)
269+
270+
271+
class TestAzureServicePrincipalCredentialProvider:
272+
@pytest.fixture
273+
def credential_provider(self):
274+
return AzureServicePrincipalCredentialProvider(
275+
hostname="hostname",
276+
oauth_client_id="client_id",
277+
oauth_client_secret="client_secret",
278+
azure_tenant_id="tenant_id",
279+
)
280+
281+
def test_provider_credentials(self, credential_provider):
282+
283+
test_token = Token("access_token", "Bearer", "refresh_token")
284+
285+
with patch.object(
286+
credential_provider, "get_token_source"
287+
) as mock_get_token_source:
288+
mock_get_token_source.return_value = MagicMock()
289+
mock_get_token_source.return_value.get_token.return_value = test_token
290+
291+
headers = credential_provider()()
292+
293+
assert headers["Authorization"] == f"Bearer {test_token.access_token}"
294+
assert (
295+
headers["X-Databricks-Azure-SP-Management-Token"]
296+
== test_token.access_token
297+
)

0 commit comments

Comments
 (0)