Skip to content

Commit bef7ac6

Browse files
committed
working
1 parent 05ab3e8 commit bef7ac6

File tree

4 files changed

+213
-97
lines changed

4 files changed

+213
-97
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
1-
from enum import Enum
21
from typing import Optional, List
32

43
from databricks.sql.auth.authenticators import (
54
AuthProvider,
65
AccessTokenAuthProvider,
76
ExternalAuthProvider,
87
DatabricksOAuthProvider,
8+
AzureServicePrincipalCredentialProvider,
99
)
10-
11-
12-
class AuthType(Enum):
13-
DATABRICKS_OAUTH = "databricks-oauth"
14-
AZURE_OAUTH = "azure-oauth"
15-
# other supported types (access_token) can be inferred
16-
# we can add more types as needed later
10+
from databricks.sql.common.auth import AuthType
1711

1812

1913
class ClientContext:
@@ -24,6 +18,9 @@ def __init__(
2418
auth_type: Optional[str] = None,
2519
oauth_scopes: Optional[List[str]] = None,
2620
oauth_client_id: Optional[str] = None,
21+
oauth_client_secret: Optional[str] = None,
22+
azure_tenant_id: Optional[str] = None,
23+
azure_workspace_resource_id: Optional[str] = None,
2724
oauth_redirect_port_range: Optional[List[int]] = None,
2825
use_cert_as_auth: Optional[str] = None,
2926
tls_client_cert_file: Optional[str] = None,
@@ -35,6 +32,9 @@ def __init__(
3532
self.auth_type = auth_type
3633
self.oauth_scopes = oauth_scopes
3734
self.oauth_client_id = oauth_client_id
35+
self.oauth_client_secret = oauth_client_secret
36+
self.azure_tenant_id = azure_tenant_id
37+
self.azure_workspace_resource_id = azure_workspace_resource_id
3838
self.oauth_redirect_port_range = oauth_redirect_port_range
3939
self.use_cert_as_auth = use_cert_as_auth
4040
self.tls_client_cert_file = tls_client_cert_file
@@ -45,7 +45,17 @@ def __init__(
4545
def get_auth_provider(cfg: ClientContext):
4646
if cfg.credentials_provider:
4747
return ExternalAuthProvider(cfg.credentials_provider)
48-
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
48+
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
49+
return ExternalAuthProvider(
50+
AzureServicePrincipalCredentialProvider(
51+
cfg.hostname,
52+
cfg.oauth_client_id,
53+
cfg.oauth_client_secret,
54+
cfg.azure_tenant_id,
55+
cfg.azure_workspace_resource_id,
56+
)
57+
)
58+
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
4959
assert cfg.oauth_redirect_port_range is not None
5060
assert cfg.oauth_client_id is not None
5161
assert cfg.oauth_scopes is not None
@@ -103,9 +113,15 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
103113

104114
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
105115
auth_type = kwargs.get("auth_type")
106-
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
107-
auth_type == AuthType.AZURE_OAUTH.value
108-
)
116+
client_id = kwargs.get("oauth_client_id")
117+
redirect_port_range = kwargs.get("oauth_redirect_port_range")
118+
119+
if auth_type == AuthType.AZURE_SP_M2M.value:
120+
pass
121+
else:
122+
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
123+
auth_type == AuthType.AZURE_OAUTH.value
124+
)
109125
if kwargs.get("username") or kwargs.get("password"):
110126
raise ValueError(
111127
"Username/password authentication is no longer supported. "
@@ -119,9 +135,12 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
119135
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
120136
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
121137
oauth_scopes=PYSQL_OAUTH_SCOPES,
122-
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
138+
oauth_client_id=client_id,
139+
oauth_client_secret=kwargs.get("oauth_client_secret"),
140+
azure_tenant_id=kwargs.get("azure_tenant_id"),
141+
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
123142
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
124-
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
143+
if client_id and kwargs.get("oauth_redirect_port")
125144
else redirect_port_range,
126145
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
127146
credentials_provider=kwargs.get("credentials_provider"),

src/databricks/sql/auth/authenticators.py

Lines changed: 58 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import abc
2-
import jwt
32
import logging
4-
import time
53
from typing import Callable, Dict, List
6-
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
7-
from databricks.sql.auth.oauth import OAuthManager
4+
from databricks.sql.common.http import HttpHeader
5+
from databricks.sql.auth.oauth import (
6+
OAuthManager,
7+
RefreshableTokenSource,
8+
ClientCredentialsTokenSource,
9+
)
810
from databricks.sql.auth.endpoint import get_oauth_endpoints
9-
from databricks.sql.common.http import DatabricksHttpClient, OAuthResponse
10-
from urllib.parse import urlencode
11+
from databricks.sql.common.auth import AuthType, get_effective_azure_login_app_id
1112

1213
# Private API: this is an evolving interface and it will change in the future.
1314
# Please do not depend on it in your applications.
@@ -38,35 +39,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
3839
...
3940

4041

41-
class Token:
42-
"""
43-
A class to represent a token.
44-
45-
Attributes:
46-
access_token (str): The access token string.
47-
token_type (str): The type of token (e.g., "Bearer").
48-
refresh_token (str): The refresh token string.
49-
"""
50-
51-
def __init__(self, access_token: str, token_type: str, refresh_token: str):
52-
self.access_token = access_token
53-
self.token_type = token_type
54-
self.refresh_token = refresh_token
55-
56-
def is_expired(self):
57-
try:
58-
decoded_token = jwt.decode(
59-
self.access_token, options={"verify_signature": False}
60-
)
61-
exp_time = decoded_token.get("exp")
62-
current_time = time.time()
63-
buffer_time = 30 # 30 seconds buffer
64-
return exp_time and (exp_time - buffer_time) <= current_time
65-
except Exception as e:
66-
logger.error("Failed to decode token: %s", e)
67-
return e
68-
69-
7042
# Private API: this is an evolving interface and it will change in the future.
7143
# Please do not depend on it in your applications.
7244
class AccessTokenAuthProvider(AuthProvider):
@@ -192,64 +164,68 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
192164
from Azure AD and automatically refreshes them when they expire.
193165
194166
Attributes:
195-
client_id (str): The Azure service principal's client ID.
196-
client_secret (str): The Azure service principal's client secret.
197-
tenant_id (str): The Azure AD tenant ID.
167+
hostname (str): The Databricks workspace hostname.
168+
oauth_client_id (str): The Azure service principal's client ID.
169+
oauth_client_secret (str): The Azure service principal's client secret.
170+
azure_tenant_id (str): The Azure AD tenant ID.
171+
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
198172
"""
199173

200174
AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
201175
AZURE_TOKEN_ENDPOINT = "oauth2/token"
202176

203-
def __init__(self, client_id: str, client_secret: str, tenant_id: str):
204-
self.client_id = client_id
205-
self.client_secret = client_secret
206-
self.tenant_id = tenant_id
207-
self._token: Token = None
208-
self._http_client = DatabricksHttpClient.get_instance()
177+
AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"
178+
179+
DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
180+
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
181+
"X-Databricks-Azure-Workspace-Resource-Id"
182+
)
183+
184+
def __init__(
185+
self,
186+
hostname: str,
187+
oauth_client_id: str,
188+
oauth_client_secret: str,
189+
azure_tenant_id: str,
190+
azure_workspace_resource_id: str = None,
191+
):
192+
self.hostname = hostname
193+
self.oauth_client_id = oauth_client_id
194+
self.oauth_client_secret = oauth_client_secret
195+
self.azure_tenant_id = azure_tenant_id
196+
self.azure_workspace_resource_id = azure_workspace_resource_id
209197

210198
def auth_type(self) -> str:
211-
return "azure-service-principal"
199+
return AuthType.AZURE_SP_M2M.value
200+
201+
def get_token_source(self, resource: str) -> RefreshableTokenSource:
202+
return ClientCredentialsTokenSource(
203+
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
204+
oauth_client_id=self.oauth_client_id,
205+
oauth_client_secret=self.oauth_client_secret,
206+
extra_params={"resource": resource},
207+
)
212208

213209
def __call__(self, *args, **kwargs) -> HeaderFactory:
214-
def header_factory() -> Dict[str, str]:
215-
self._refresh()
216-
return {
217-
HttpHeader.AUTHORIZATION.value: f"{self._token.token_type} {self._token.access_token}",
218-
}
219-
220-
return header_factory
210+
inner = self.get_token_source(
211+
resource=get_effective_azure_login_app_id(self.hostname)
212+
)
213+
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
221214

222-
def _refresh(self) -> None:
223-
if self._token is None or self._token.is_expired():
224-
self._token = self._get_token()
215+
def header_factory() -> Dict[str, str]:
216+
inner_token = inner.get_token()
217+
cloud_token = cloud.get_token()
225218

226-
def _get_token(self) -> Token:
227-
request_url = (
228-
f"{self.AZURE_AAD_ENDPOINT}/{self.tenant_id}/{self.AZURE_TOKEN_ENDPOINT}"
229-
)
230-
headers = {
231-
HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded",
232-
}
233-
data = urlencode(
234-
{
235-
"grant_type": "client_credentials",
236-
"client_id": self.client_id,
237-
"client_secret": self.client_secret,
219+
headers = {
220+
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
221+
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
238222
}
239-
)
240223

241-
response = self._http_client.execute(
242-
method=HttpMethod.POST, url=request_url, headers=headers, data=data
243-
)
224+
if self.azure_workspace_resource_id:
225+
headers[
226+
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
227+
] = self.azure_workspace_resource_id
244228

245-
if response.status_code == 200:
246-
oauth_response = OAuthResponse(**response.json())
247-
return Token(
248-
oauth_response.access_token,
249-
oauth_response.token_type,
250-
oauth_response.refresh_token,
251-
)
252-
else:
253-
raise Exception(
254-
f"Failed to get token: {response.status_code} {response.text}"
255-
)
229+
return headers
230+
231+
return header_factory

src/databricks/sql/auth/oauth.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,57 @@
1212
import requests
1313
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
1414
from requests.exceptions import RequestException
15-
15+
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
16+
from databricks.sql.common.http import OAuthResponse
1617
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
1718
from databricks.sql.auth.endpoint import OAuthEndpointCollection
19+
from abc import abstractmethod, ABC
20+
from urllib.parse import urlencode
21+
import jwt
22+
import time
1823

1924
logger = logging.getLogger(__name__)
2025

2126

27+
class Token:
    """
    A class to represent a token.

    Attributes:
        access_token (str): The access token string.
        token_type (str): The type of token (e.g., "Bearer").
        refresh_token (str): The refresh token string.
    """

    def __init__(self, access_token: str, token_type: str, refresh_token: str):
        self.access_token = access_token
        self.token_type = token_type
        self.refresh_token = refresh_token

    def is_expired(self) -> bool:
        """
        Report whether the access token should be considered expired.

        The access token is decoded as a JWT without verifying its signature
        and its ``exp`` claim is compared against the current time, with a
        30 second safety buffer.

        Returns:
            bool: True when ``exp`` minus the buffer is at or before the
            current time, or when the token cannot be decoded or compared
            (so that callers fetch a fresh one). False when the token has no
            ``exp`` claim or is still valid.
        """
        buffer_time = 30  # 30 seconds buffer
        try:
            decoded_token = jwt.decode(
                self.access_token, options={"verify_signature": False}
            )
            exp_time = decoded_token.get("exp")
            if not exp_time:
                # No expiry claim: nothing to compare against.
                return False
            return (exp_time - buffer_time) <= time.time()
        except Exception as e:
            logger.error("Failed to decode token: %s", e)
            # An unreadable token is reported as expired; return a real bool
            # rather than leaking the exception object to the caller.
            return True
54+
55+
56+
class RefreshableTokenSource(ABC):
    """
    Abstract interface for objects that hand out tokens and can renew them.

    Concrete subclasses must implement both ``get_token`` and ``refresh``.
    """

    @abstractmethod
    def get_token(self) -> "Token":
        """Return a token for the caller to use."""

    @abstractmethod
    def refresh(self) -> "Token":
        """Obtain a new token and return it."""
64+
65+
2266
class IgnoreNetrcAuth(requests.auth.AuthBase):
2367
"""This auth method is a no-op.
2468
@@ -258,3 +302,53 @@ def get_tokens(self, hostname: str, scope=None):
258302
client, token_request_url, redirect_url, code, verifier
259303
)
260304
return self.__get_tokens_from_response(oauth_response)
305+
306+
307+
class ClientCredentialsTokenSource(RefreshableTokenSource):
    """
    Token source that requests tokens with the ``client_credentials`` grant.

    The last token obtained is kept on the instance and is replaced once
    ``Token.is_expired()`` reports it as expired.

    Attributes:
        token_url (str): URL the token request is POSTed to.
        oauth_client_id (str): Client ID sent in the request body.
        oauth_client_secret (str): Client secret sent in the request body.
        extra_params (dict): Additional form fields added to the request body.
        token (Token): The cached token, or None before the first request.
    """

    def __init__(
        self,
        token_url: str,
        oauth_client_id: str,
        oauth_client_secret: str,
        extra_params: dict = None,
    ):
        self.oauth_client_id = oauth_client_id
        self.oauth_client_secret = oauth_client_secret
        self.token_url = token_url
        # Normalise the omitted case to an empty mapping so refresh() can
        # always unpack it; unpacking None would raise TypeError.
        self.extra_params = extra_params if extra_params is not None else {}
        self.token: Token = None
        self._http_client = DatabricksHttpClient()

    def get_token(self) -> Token:
        """Return the cached token, fetching a new one if it is missing or expired."""
        if self.token is None or self.token.is_expired():
            self.token = self.refresh()
        return self.token

    def refresh(self) -> Token:
        """
        Request a new token from ``token_url`` and return it.

        The request is a form-encoded POST. Entries in ``extra_params`` are
        merged last, so they take precedence over the standard fields on a
        key collision.

        Returns:
            Token: Built from the access token, token type and refresh token
            of the response.

        Raises:
            Exception: If the response status code is not 200.
        """
        headers = {
            HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded",
        }
        data = urlencode(
            {
                "grant_type": "client_credentials",
                "client_id": self.oauth_client_id,
                "client_secret": self.oauth_client_secret,
                **self.extra_params,
            }
        )

        response = self._http_client.execute(
            method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
        )

        if response.status_code != 200:
            raise Exception(
                f"Failed to get token: {response.status_code} {response.text}"
            )

        oauth_response = OAuthResponse(**response.json())
        return Token(
            oauth_response.access_token,
            oauth_response.token_type,
            oauth_response.refresh_token,
        )

0 commit comments

Comments
 (0)