Skip to content

Commit 5e2b792

Browse files
committed
basic setup
1 parent 7696316 commit 5e2b792

File tree

6 files changed

+404
-40
lines changed

6 files changed

+404
-40
lines changed

poetry.lock

Lines changed: 101 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@ requests = "^2.18.1"
2020
oauthlib = "^3.1.0"
2121
openpyxl = "^3.0.10"
2222
urllib3 = ">=1.26"
23+
python-dateutil = "^2.8.0"
2324
pyarrow = [
2425
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
2526
{ version = ">=18.0.0", python = ">=3.13", optional=true }
2627
]
27-
python-dateutil = "^2.8.0"
28+
pyjwt = { version = "^2.0.0", optional = true }
29+
2830

2931
[tool.poetry.extras]
3032
pyarrow = ["pyarrow"]
33+
jwt = ["pyjwt"]
3134

32-
[tool.poetry.dev-dependencies]
35+
[tool.poetry.group.dev.dependencies]
3336
pytest = "^7.1.2"
3437
mypy = "^1.10.1"
3538
pylint = ">=2.12.0"

src/databricks/sql/auth/authenticators.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import abc
2-
import base64
2+
import jwt
33
import logging
4+
import time
45
from typing import Callable, Dict, List
5-
6+
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
67
from databricks.sql.auth.oauth import OAuthManager
7-
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
8+
from databricks.sql.auth.endpoint import get_oauth_endpoints
9+
from databricks.sql.common.http import DatabricksHttpClient, OAuthResponse
810

911
# Private API: this is an evolving interface and it will change in the future.
1012
# Please must not depend on it in your applications.
1113
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence
1214

15+
logger = logging.getLogger(__name__)
16+
1317

1418
class AuthProvider:
1519
def add_headers(self, request_headers: Dict[str, str]):
@@ -33,6 +37,35 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
3337
...
3438

3539

40+
class Token:
41+
"""
42+
A class to represent a token.
43+
44+
Attributes:
45+
access_token (str): The access token string.
46+
token_type (str): The type of token (e.g., "Bearer").
47+
refresh_token (str): The refresh token string.
48+
"""
49+
50+
def __init__(self, access_token: str, token_type: str, refresh_token: str):
51+
self.access_token = access_token
52+
self.token_type = token_type
53+
self.refresh_token = refresh_token
54+
55+
def is_expired(self):
56+
try:
57+
decoded_token = jwt.decode(
58+
self.access_token, options={"verify_signature": False}
59+
)
60+
exp_time = decoded_token.get("exp")
61+
current_time = time.time()
62+
buffer_time = 30 # 30 seconds buffer
63+
return exp_time and (exp_time - buffer_time) <= current_time
64+
except Exception as e:
65+
logger.error("Failed to decode token: %s", e)
66+
return e
67+
68+
3669
# Private API: this is an evolving interface and it will change in the future.
3770
# Please must not depend on it in your applications.
3871
class AccessTokenAuthProvider(AuthProvider):
@@ -146,3 +179,74 @@ def add_headers(self, request_headers: Dict[str, str]):
146179
headers = self._header_factory()
147180
for k, v in headers.items():
148181
request_headers[k] = v
182+
183+
184+
class AzureServicePrincipalCredentialProvider(CredentialsProvider):
185+
"""
186+
A credential provider for Azure Service Principal authentication with Databricks.
187+
188+
This class implements the CredentialsProvider protocol to authenticate requests
189+
to Databricks REST APIs using Azure Active Directory (AAD) service principal
190+
credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens
191+
from Azure AD and automatically refreshes them when they expire.
192+
193+
Attributes:
194+
client_id (str): The Azure service principal's client ID.
195+
client_secret (str): The Azure service principal's client secret.
196+
tenant_id (str): The Azure AD tenant ID.
197+
"""
198+
199+
AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
200+
DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default"
201+
AZURE_TOKEN_ENDPOINT = "oauth2/token"
202+
203+
def __init__(self, client_id: str, client_secret: str, tenant_id: str):
204+
self.client_id = client_id
205+
self.client_secret = client_secret
206+
self.tenant_id = tenant_id
207+
self._token: Token = None
208+
self._http_client = DatabricksHttpClient.get_instance()
209+
210+
def auth_type(self) -> str:
211+
return "azure-service-principal"
212+
213+
def __call__(self, *args, **kwargs) -> HeaderFactory:
214+
def header_factory() -> Dict[str, str]:
215+
self._refresh()
216+
return {HttpHeader.AUTHORIZATION: f"Bearer {self._token.access_token}"}
217+
218+
return header_factory
219+
220+
def _refresh(self) -> None:
221+
if self._token is None or self._token.is_expired():
222+
self._token = self._get_token()
223+
224+
def _get_token(self) -> Token:
225+
request_url = (
226+
f"{self.AZURE_AAD_ENDPOINT}/{self.tenant_id}/{self.AZURE_TOKEN_ENDPOINT}"
227+
)
228+
headers = {
229+
HttpHeader.CONTENT_TYPE: "application/x-www-form-urlencoded",
230+
}
231+
data = {
232+
"grant_type": "client_credentials",
233+
"client_id": self.client_id,
234+
"client_secret": self.client_secret,
235+
"scope": self.DATABRICKS_SCOPE,
236+
}
237+
238+
response = self._http_client.execute(
239+
method=HttpMethod.POST, url=request_url, headers=headers, data=data
240+
)
241+
242+
if response.status_code == 200:
243+
oauth_response = OAuthResponse(**response.json())
244+
return Token(
245+
oauth_response.access_token,
246+
oauth_response.token_type,
247+
oauth_response.refresh_token,
248+
)
249+
else:
250+
raise Exception(
251+
f"Failed to get token: {response.status_code} {response.text}"
252+
)

src/databricks/sql/common/http.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import requests
2+
from requests.adapters import HTTPAdapter
3+
from urllib3.util.retry import Retry
4+
from enum import Enum
5+
import threading
6+
from dataclasses import dataclass
7+
8+
# Enums for HTTP Methods
9+
class HttpMethod(str, Enum):
10+
GET = "GET"
11+
POST = "POST"
12+
PUT = "PUT"
13+
DELETE = "DELETE"
14+
15+
16+
class HttpHeader(str, Enum):
17+
CONTENT_TYPE = "Content-Type"
18+
AUTHORIZATION = "Authorization"
19+
20+
21+
# Dataclass for HTTP Response
22+
@dataclass
23+
class OAuthResponse:
24+
token_type: str = ""
25+
expires_in: int = 0
26+
ext_expires_in: int = 0
27+
expires_on: int = 0
28+
not_before: int = 0
29+
resource: str = ""
30+
access_token: str = ""
31+
refresh_token: str = ""
32+
33+
34+
# Singleton class for common Http Client
35+
class DatabricksHttpClient:
36+
## TODO: Unify all the http clients in the PySQL Connector
37+
38+
_instance = None
39+
_lock = threading.Lock()
40+
41+
def __init__(self):
42+
self.session = requests.Session()
43+
adapter = HTTPAdapter(
44+
pool_connections=5,
45+
pool_maxsize=10,
46+
max_retries=Retry(total=10, backoff_factor=0.1),
47+
)
48+
self.session.mount("https://", adapter)
49+
self.session.mount("http://", adapter)
50+
51+
@classmethod
52+
def get_instance(cls) -> "DatabricksHttpClient":
53+
if cls._instance is None:
54+
with cls._lock:
55+
if cls._instance is None:
56+
cls._instance = DatabricksHttpClient()
57+
return cls._instance
58+
59+
def execute(self, method: HttpMethod, url: str, **kwargs) -> requests.Response:
60+
with self.session.request(method, url, **kwargs) as response:
61+
return response
62+
63+
def close(self):
64+
self.session.close()

tests/unit/test_auth.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
import unittest
22
import pytest
33
from typing import Optional
4-
from unittest.mock import patch
5-
4+
from unittest.mock import patch, MagicMock
5+
import jwt
66
from databricks.sql.auth.auth import (
77
AccessTokenAuthProvider,
88
AuthProvider,
99
ExternalAuthProvider,
1010
AuthType,
1111
)
12+
import time
13+
from datetime import datetime, timedelta
1214
from databricks.sql.auth.auth import (
1315
get_python_sql_connector_auth_provider,
1416
PYSQL_OAUTH_CLIENT_ID,
1517
)
1618
from databricks.sql.auth.oauth import OAuthManager
17-
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
19+
from databricks.sql.auth.authenticators import (
20+
DatabricksOAuthProvider,
21+
AzureServicePrincipalCredentialProvider,
22+
Token,
23+
)
1824
from databricks.sql.auth.endpoint import (
1925
CloudType,
2026
InHouseOAuthEndpointCollection,
@@ -190,3 +196,95 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
190196
auth_provider = get_python_sql_connector_auth_provider(hostname)
191197
self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
192198
self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
199+
200+
201+
class TestAzureServicePrincipalCredentialProvider:
202+
@pytest.fixture
203+
def indefinite_token(self):
204+
secret_key = "mysecret"
205+
expires_in_100_years = int(time.time()) + (100 * 365 * 24 * 60 * 60)
206+
207+
payload = {"sub": "user123", "role": "admin", "exp": expires_in_100_years}
208+
209+
token = jwt.encode(payload, secret_key, algorithm="HS256")
210+
return Token(token, "Bearer", "refresh_token")
211+
212+
@pytest.fixture
213+
def http_response(self):
214+
def status_response(response_status_code):
215+
mock_response = MagicMock()
216+
mock_response.status_code = response_status_code
217+
mock_response.json.return_value = {
218+
"access_token": "abc123",
219+
"token_type": "Bearer",
220+
"refresh_token": None,
221+
}
222+
return mock_response
223+
224+
return status_response
225+
226+
@pytest.fixture
227+
def provider(self):
228+
return AzureServicePrincipalCredentialProvider(
229+
client_id="dummy-client",
230+
client_secret="dummy-secret",
231+
tenant_id="dummy-tenant",
232+
)
233+
234+
def test_token_refresh(self, provider):
235+
with patch.object(provider, "_get_token") as mock_get_token:
236+
mock_get_token.return_value = Token(
237+
"access_token", "Bearer", "refresh_token"
238+
)
239+
header_factory = provider()
240+
headers = header_factory()
241+
242+
assert headers["Authorization"] == "Bearer access_token"
243+
mock_get_token.assert_called_once()
244+
245+
def test_no_token_refresh__when_token_is_not_expired(
246+
self, provider, indefinite_token
247+
):
248+
with patch.object(provider, "_get_token") as mock_get_token:
249+
mock_get_token.return_value = indefinite_token
250+
251+
# Call the provider multiple times
252+
header_factory1 = provider()
253+
header_factory2 = provider()
254+
header_factory3 = provider()
255+
256+
# Get headers from each factory
257+
headers1 = header_factory1()
258+
headers2 = header_factory2()
259+
headers3 = header_factory3()
260+
261+
# Verify _get_token was called only once
262+
mock_get_token.assert_called_once()
263+
264+
# Verify all headers contain the same token
265+
expected_auth_header = f"Bearer {indefinite_token.access_token}"
266+
assert headers1["Authorization"] == expected_auth_header
267+
assert headers2["Authorization"] == expected_auth_header
268+
assert headers3["Authorization"] == expected_auth_header
269+
270+
def test_get_token_success(self, provider, http_response):
271+
272+
# Patch the HTTP client's execute method
273+
with patch.object(
274+
provider._http_client, "execute", return_value=http_response(200)
275+
) as mock_execute:
276+
token = provider._get_token()
277+
278+
# Assert
279+
assert isinstance(token, Token)
280+
assert token.access_token == "abc123"
281+
assert token.token_type == "Bearer"
282+
assert token.refresh_token is None
283+
284+
def test_get_token_failure(self, provider, http_response):
285+
with patch.object(
286+
provider._http_client, "execute", return_value=http_response(400)
287+
) as mock_execute:
288+
with pytest.raises(Exception) as e:
289+
provider._get_token()
290+
assert "Failed to get token: 400" in str(e.value)

0 commit comments

Comments
 (0)