|
1 | 1 | import abc |
2 | | -import jwt |
3 | 2 | import logging |
4 | | -import time |
5 | 3 | from typing import Callable, Dict, List |
6 | | -from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader |
7 | | -from databricks.sql.auth.oauth import OAuthManager |
| 4 | +from databricks.sql.common.http import HttpHeader |
| 5 | +from databricks.sql.auth.oauth import ( |
| 6 | + OAuthManager, |
| 7 | + RefreshableTokenSource, |
| 8 | + ClientCredentialsTokenSource, |
| 9 | +) |
8 | 10 | from databricks.sql.auth.endpoint import get_oauth_endpoints |
9 | | -from databricks.sql.common.http import DatabricksHttpClient, OAuthResponse |
10 | | -from urllib.parse import urlencode |
| 11 | +from databricks.sql.common.auth import AuthType, get_effective_azure_login_app_id |
11 | 12 |
|
12 | 13 | # Private API: this is an evolving interface and it will change in the future. |
13 | 14 | # Please must not depend on it in your applications. |
@@ -38,35 +39,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: |
38 | 39 | ... |
39 | 40 |
|
40 | 41 |
|
41 | | -class Token: |
42 | | - """ |
43 | | - A class to represent a token. |
44 | | -
|
45 | | - Attributes: |
46 | | - access_token (str): The access token string. |
47 | | - token_type (str): The type of token (e.g., "Bearer"). |
48 | | - refresh_token (str): The refresh token string. |
49 | | - """ |
50 | | - |
51 | | - def __init__(self, access_token: str, token_type: str, refresh_token: str): |
52 | | - self.access_token = access_token |
53 | | - self.token_type = token_type |
54 | | - self.refresh_token = refresh_token |
55 | | - |
56 | | - def is_expired(self): |
57 | | - try: |
58 | | - decoded_token = jwt.decode( |
59 | | - self.access_token, options={"verify_signature": False} |
60 | | - ) |
61 | | - exp_time = decoded_token.get("exp") |
62 | | - current_time = time.time() |
63 | | - buffer_time = 30 # 30 seconds buffer |
64 | | - return exp_time and (exp_time - buffer_time) <= current_time |
65 | | - except Exception as e: |
66 | | - logger.error("Failed to decode token: %s", e) |
67 | | - return e |
68 | | - |
69 | | - |
70 | 42 | # Private API: this is an evolving interface and it will change in the future. |
71 | 43 | # Please must not depend on it in your applications. |
72 | 44 | class AccessTokenAuthProvider(AuthProvider): |
@@ -192,64 +164,68 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider): |
192 | 164 | from Azure AD and automatically refreshes them when they expire. |
193 | 165 |
|
194 | 166 | Attributes: |
195 | | - client_id (str): The Azure service principal's client ID. |
196 | | - client_secret (str): The Azure service principal's client secret. |
197 | | - tenant_id (str): The Azure AD tenant ID. |
| 167 | + hostname (str): The Databricks workspace hostname. |
| 168 | + oauth_client_id (str): The Azure service principal's client ID. |
| 169 | + oauth_client_secret (str): The Azure service principal's client secret. |
| 170 | + azure_tenant_id (str): The Azure AD tenant ID. |
| 171 | + azure_workspace_resource_id (str, optional): The Azure workspace resource ID. |
198 | 172 | """ |
199 | 173 |
|
200 | 174 | AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com" |
201 | 175 | AZURE_TOKEN_ENDPOINT = "oauth2/token" |
202 | 176 |
|
203 | | - def __init__(self, client_id: str, client_secret: str, tenant_id: str): |
204 | | - self.client_id = client_id |
205 | | - self.client_secret = client_secret |
206 | | - self.tenant_id = tenant_id |
207 | | - self._token: Token = None |
208 | | - self._http_client = DatabricksHttpClient.get_instance() |
| 177 | + AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/" |
| 178 | + |
| 179 | + DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token" |
| 180 | + DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = ( |
| 181 | + "X-Databricks-Azure-Workspace-Resource-Id" |
| 182 | + ) |
| 183 | + |
| 184 | + def __init__( |
| 185 | + self, |
| 186 | + hostname: str, |
| 187 | + oauth_client_id: str, |
| 188 | + oauth_client_secret: str, |
| 189 | + azure_tenant_id: str, |
| 190 | + azure_workspace_resource_id: str = None, |
| 191 | + ): |
| 192 | + self.hostname = hostname |
| 193 | + self.oauth_client_id = oauth_client_id |
| 194 | + self.oauth_client_secret = oauth_client_secret |
| 195 | + self.azure_tenant_id = azure_tenant_id |
| 196 | + self.azure_workspace_resource_id = azure_workspace_resource_id |
209 | 197 |
|
210 | 198 | def auth_type(self) -> str: |
211 | | - return "azure-service-principal" |
| 199 | + return AuthType.AZURE_SP_M2M.value |
| 200 | + |
| 201 | + def get_token_source(self, resource: str) -> RefreshableTokenSource: |
| 202 | + return ClientCredentialsTokenSource( |
| 203 | + token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}", |
| 204 | + oauth_client_id=self.oauth_client_id, |
| 205 | + oauth_client_secret=self.oauth_client_secret, |
| 206 | + extra_params={"resource": resource}, |
| 207 | + ) |
212 | 208 |
|
213 | 209 | def __call__(self, *args, **kwargs) -> HeaderFactory: |
214 | | - def header_factory() -> Dict[str, str]: |
215 | | - self._refresh() |
216 | | - return { |
217 | | - HttpHeader.AUTHORIZATION.value: f"{self._token.token_type} {self._token.access_token}", |
218 | | - } |
219 | | - |
220 | | - return header_factory |
| 210 | + inner = self.get_token_source( |
| 211 | + resource=get_effective_azure_login_app_id(self.hostname) |
| 212 | + ) |
| 213 | + cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE) |
221 | 214 |
|
222 | | - def _refresh(self) -> None: |
223 | | - if self._token is None or self._token.is_expired(): |
224 | | - self._token = self._get_token() |
| 215 | + def header_factory() -> Dict[str, str]: |
| 216 | + inner_token = inner.get_token() |
| 217 | + cloud_token = cloud.get_token() |
225 | 218 |
|
226 | | - def _get_token(self) -> Token: |
227 | | - request_url = ( |
228 | | - f"{self.AZURE_AAD_ENDPOINT}/{self.tenant_id}/{self.AZURE_TOKEN_ENDPOINT}" |
229 | | - ) |
230 | | - headers = { |
231 | | - HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded", |
232 | | - } |
233 | | - data = urlencode( |
234 | | - { |
235 | | - "grant_type": "client_credentials", |
236 | | - "client_id": self.client_id, |
237 | | - "client_secret": self.client_secret, |
| 219 | + headers = { |
| 220 | + HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}", |
| 221 | + self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token, |
238 | 222 | } |
239 | | - ) |
240 | 223 |
|
241 | | - response = self._http_client.execute( |
242 | | - method=HttpMethod.POST, url=request_url, headers=headers, data=data |
243 | | - ) |
| 224 | + if self.azure_workspace_resource_id: |
| 225 | + headers[ |
| 226 | + self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER |
| 227 | + ] = self.azure_workspace_resource_id |
244 | 228 |
|
245 | | - if response.status_code == 200: |
246 | | - oauth_response = OAuthResponse(**response.json()) |
247 | | - return Token( |
248 | | - oauth_response.access_token, |
249 | | - oauth_response.token_type, |
250 | | - oauth_response.refresh_token, |
251 | | - ) |
252 | | - else: |
253 | | - raise Exception( |
254 | | - f"Failed to get token: {response.status_code} {response.text}" |
255 | | - ) |
| 229 | + return headers |
| 230 | + |
| 231 | + return header_factory |
0 commit comments