Skip to content
111 changes: 100 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ requests = "^2.18.1"
oauthlib = "^3.1.0"
openpyxl = "^3.0.10"
urllib3 = ">=1.26"
python-dateutil = "^2.8.0"
pyarrow = [
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
{ version = ">=18.0.0", python = ">=3.13", optional=true }
]
python-dateutil = "^2.8.0"
pyjwt = "^2.0.0"


[tool.poetry.extras]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
mypy = "^1.10.1"
pylint = ">=2.12.0"
Expand Down
47 changes: 33 additions & 14 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from enum import Enum
from typing import Optional, List

from databricks.sql.auth.authenticators import (
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
AzureServicePrincipalCredentialProvider,
)


class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
# other supported types (access_token) can be inferred
# we can add more types as needed later
from databricks.sql.auth.common import AuthType


class ClientContext:
Expand All @@ -24,6 +18,9 @@ def __init__(
auth_type: Optional[str] = None,
oauth_scopes: Optional[List[str]] = None,
oauth_client_id: Optional[str] = None,
oauth_client_secret: Optional[str] = None,
azure_tenant_id: Optional[str] = None,
azure_workspace_resource_id: Optional[str] = None,
oauth_redirect_port_range: Optional[List[int]] = None,
use_cert_as_auth: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
Expand All @@ -35,6 +32,9 @@ def __init__(
self.auth_type = auth_type
self.oauth_scopes = oauth_scopes
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.azure_tenant_id = azure_tenant_id
self.azure_workspace_resource_id = azure_workspace_resource_id
self.oauth_redirect_port_range = oauth_redirect_port_range
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
Expand All @@ -45,7 +45,17 @@ def __init__(
def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
return ExternalAuthProvider(
AzureServicePrincipalCredentialProvider(
cfg.hostname,
cfg.oauth_client_id,
cfg.oauth_client_secret,
cfg.azure_tenant_id,
cfg.azure_workspace_resource_id,
)
)
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
assert cfg.oauth_scopes is not None
Expand Down Expand Up @@ -103,9 +113,15 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):

def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
auth_type = kwargs.get("auth_type")
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
)
client_id = kwargs.get("oauth_client_id")
redirect_port_range = kwargs.get("oauth_redirect_port_range")

if auth_type == AuthType.AZURE_SP_M2M.value:
pass
else:
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
)
if kwargs.get("username") or kwargs.get("password"):
raise ValueError(
"Username/password authentication is no longer supported. "
Expand All @@ -119,9 +135,12 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
oauth_scopes=PYSQL_OAUTH_SCOPES,
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
oauth_client_id=client_id,
oauth_client_secret=kwargs.get("oauth_client_secret"),
azure_tenant_id=kwargs.get("azure_tenant_id"),
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
if client_id and kwargs.get("oauth_redirect_port")
else redirect_port_range,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
Expand Down
93 changes: 88 additions & 5 deletions src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import abc
import base64
import logging
from typing import Callable, Dict, List

from databricks.sql.auth.oauth import OAuthManager
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
from typing import Callable, Dict, List, Optional
from databricks.sql.common.http import HttpHeader
from databricks.sql.auth.oauth import (
OAuthManager,
RefreshableTokenSource,
ClientCredentialsTokenSource,
)
from databricks.sql.auth.endpoint import get_oauth_endpoints
from databricks.sql.auth.common import AuthType, get_effective_azure_login_app_id

# Private API: this is an evolving interface and it will change in the future.
# Please must not depend on it in your applications.
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence

logger = logging.getLogger(__name__)


class AuthProvider:
def add_headers(self, request_headers: Dict[str, str]):
Expand Down Expand Up @@ -146,3 +152,80 @@ def add_headers(self, request_headers: Dict[str, str]):
headers = self._header_factory()
for k, v in headers.items():
request_headers[k] = v


class AzureServicePrincipalCredentialProvider(CredentialsProvider):
"""
A credential provider for Azure Service Principal authentication with Databricks.

This class implements the CredentialsProvider protocol to authenticate requests
to Databricks REST APIs using Azure Active Directory (AAD) service principal
credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens
from Azure AD and automatically refreshes them when they expire.

Attributes:
hostname (str): The Databricks workspace hostname.
oauth_client_id (str): The Azure service principal's client ID.
oauth_client_secret (str): The Azure service principal's client secret.
azure_tenant_id (str): The Azure AD tenant ID.
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
"""

AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_ENDPOINT = "oauth2/token"

AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"

DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
"X-Databricks-Azure-Workspace-Resource-Id"
)

def __init__(
self,
hostname,
oauth_client_id,
oauth_client_secret,
azure_tenant_id,
azure_workspace_resource_id=None,
):
self.hostname = hostname
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.azure_tenant_id = azure_tenant_id
self.azure_workspace_resource_id = azure_workspace_resource_id

def auth_type(self) -> str:
return AuthType.AZURE_SP_M2M.value

def get_token_source(self, resource: str) -> RefreshableTokenSource:
return ClientCredentialsTokenSource(
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
oauth_client_id=self.oauth_client_id,
oauth_client_secret=self.oauth_client_secret,
extra_params={"resource": resource},
)

def __call__(self, *args, **kwargs) -> HeaderFactory:
inner = self.get_token_source(
resource=get_effective_azure_login_app_id(self.hostname)
)
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)

def header_factory() -> Dict[str, str]:
inner_token = inner.get_token()
cloud_token = cloud.get_token()

headers = {
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
}

if self.azure_workspace_resource_id:
headers[
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
] = self.azure_workspace_resource_id

return headers

return header_factory
28 changes: 28 additions & 0 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from enum import Enum
from typing import Optional


class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
AZURE_SP_M2M = "azure-sp-m2m"


def get_effective_azure_login_app_id(hostname) -> str:
"""
Get the effective Azure login app ID for a given hostname.
This function determines the appropriate Azure login app ID based on the hostname.
If the hostname does not match any of these domains, it returns the default Databricks resource ID.
"""
azure_app_ids = {
".dev.azuredatabricks.net": "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
".staging.azuredatabricks.net": "4a67d088-db5c-48f1-9ff2-0aace800ae68",
}

for domain, app_id in azure_app_ids.items():
if domain in hostname:
return app_id

# default databricks resource id
return "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
108 changes: 106 additions & 2 deletions src/databricks/sql/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,63 @@
import webbrowser
from datetime import datetime, timezone
from http.server import HTTPServer
from typing import List
from typing import List, Optional

import oauthlib.oauth2
import requests
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from requests.exceptions import RequestException

from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
from databricks.sql.common.http import OAuthResponse
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
from databricks.sql.auth.endpoint import OAuthEndpointCollection
from abc import abstractmethod, ABC
from urllib.parse import urlencode
import jwt
import time

logger = logging.getLogger(__name__)


class Token:
"""
A class to represent a token.

Attributes:
access_token (str): The access token string.
token_type (str): The type of token (e.g., "Bearer").
refresh_token (str): The refresh token string.
"""

def __init__(self, access_token: str, token_type: str, refresh_token: str):
self.access_token = access_token
self.token_type = token_type
self.refresh_token = refresh_token

def is_expired(self):
try:
decoded_token = jwt.decode(
self.access_token, options={"verify_signature": False}
)
exp_time = decoded_token.get("exp")
current_time = time.time()
buffer_time = 30 # 30 seconds buffer
return exp_time and (exp_time - buffer_time) <= current_time
except Exception as e:
logger.error("Failed to decode token: %s", e)
return e


class RefreshableTokenSource(ABC):
@abstractmethod
def get_token(self) -> Token:
pass

@abstractmethod
def refresh(self) -> Token:
pass


class IgnoreNetrcAuth(requests.auth.AuthBase):
"""This auth method is a no-op.

Expand Down Expand Up @@ -258,3 +302,63 @@ def get_tokens(self, hostname: str, scope=None):
client, token_request_url, redirect_url, code, verifier
)
return self.__get_tokens_from_response(oauth_response)


class ClientCredentialsTokenSource(RefreshableTokenSource):
"""
A token source that uses client credentials to get a token from the token endpoint.
It will refresh the token if it is expired.

Attributes:
token_url (str): The URL of the token endpoint.
oauth_client_id (str): The client ID.
oauth_client_secret (str): The client secret.
"""

def __init__(
self,
token_url,
oauth_client_id,
oauth_client_secret,
extra_params: dict = {},
):
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.token_url = token_url
self.extra_params = extra_params
self.token: Optional[Token] = None
self._http_client = DatabricksHttpClient.get_instance()

def get_token(self) -> Token:
if self.token is None or self.token.is_expired():
self.token = self.refresh()
return self.token

def refresh(self) -> Token:
headers = {
HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded",
}
data = urlencode(
{
"grant_type": "client_credentials",
"client_id": self.oauth_client_id,
"client_secret": self.oauth_client_secret,
**self.extra_params,
}
)

response = self._http_client.execute(
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
)

if response.status_code == 200:
oauth_response = OAuthResponse(**response.json())
return Token(
oauth_response.access_token,
oauth_response.token_type,
oauth_response.refresh_token,
)
else:
raise Exception(
f"Failed to get token: {response.status_code} {response.text}"
)
Loading
Loading