4 changes: 3 additions & 1 deletion src/databricks/sql/auth/auth.py
@@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
cfg.oauth_client_id,
cfg.oauth_scopes,
cfg.auth_type,
http_client=http_client,
)
elif cfg.access_token is not None:
return AccessTokenAuthProvider(cfg.access_token)
@@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext):
cfg.oauth_redirect_port_range,
cfg.oauth_client_id,
cfg.oauth_scopes,
http_client=http_client,
)
else:
raise RuntimeError("No valid authentication settings!")
@@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
)


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
# TODO : unify all the auth mechanisms with the Python SDK

auth_type = kwargs.get("auth_type")
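Since get_python_sql_connector_auth_provider now takes the HTTP client as a required positional parameter, callers have to build the client first and pass it through. A minimal call-site sketch follows; only the function name and parameter order come from this diff, and the hostname and keyword arguments are placeholders:

# Call-site sketch (illustrative); only the function name and parameter order
# come from this diff, the values below are placeholders.
auth_provider = get_python_sql_connector_auth_provider(
    "my-workspace.cloud.databricks.com",   # example hostname
    http_client,                           # unified HTTP client built by the session layer
    auth_type="databricks-oauth",
)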
2 changes: 2 additions & 0 deletions src/databricks/sql/auth/authenticators.py
@@ -63,6 +63,7 @@ def __init__(
redirect_port_range: List[int],
client_id: str,
scopes: List[str],
http_client,
auth_type: str = "databricks-oauth",
):
try:
@@ -79,6 +80,7 @@ def __init__(
port_range=redirect_port_range,
client_id=client_id,
idp_endpoint=idp_endpoint,
http_client=http_client,
)
self._hostname = hostname
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
61 changes: 45 additions & 16 deletions src/databricks/sql/auth/common.py
@@ -2,7 +2,6 @@
import logging
from typing import Optional, List
from urllib.parse import urlparse
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod

logger = logging.getLogger(__name__)

@@ -36,6 +35,21 @@ def __init__(
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
# HTTP client configuration parameters
ssl_options=None, # SSLOptions type
socket_timeout: Optional[float] = None,
retry_stop_after_attempts_count: Optional[int] = None,
retry_delay_min: Optional[float] = None,
retry_delay_max: Optional[float] = None,
retry_stop_after_attempts_duration: Optional[float] = None,
retry_delay_default: Optional[float] = None,
retry_dangerous_codes: Optional[List[int]] = None,
http_proxy: Optional[str] = None,
proxy_username: Optional[str] = None,
proxy_password: Optional[str] = None,
pool_connections: Optional[int] = None,
pool_maxsize: Optional[int] = None,
user_agent: Optional[str] = None,
):
self.hostname = hostname
self.access_token = access_token
@@ -51,6 +65,22 @@ def __init__(
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider

# HTTP client configuration
self.ssl_options = ssl_options
self.socket_timeout = socket_timeout
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
self.retry_delay_min = retry_delay_min or 1.0
self.retry_delay_max = retry_delay_max or 60.0
self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
self.retry_delay_default = retry_delay_default or 5.0
self.retry_dangerous_codes = retry_dangerous_codes or []
self.http_proxy = http_proxy
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.pool_connections = pool_connections or 10
self.pool_maxsize = pool_maxsize or 20
self.user_agent = user_agent
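One behavioural note on the defaults above: the value-or-default pattern falls back to the default for any falsy value, so an explicit 0 or 0.0 (for example retry_delay_min=0.0) is silently replaced. A tiny illustrative snippet, not part of the diff, contrasting this with an explicit None check:

# Illustrative only: "or" treats falsy values (0, 0.0, [], "") as "not provided".
retry_delay_min = 0.0
print(retry_delay_min or 1.0)                                    # 1.0 -- explicit 0.0 is lost
print(retry_delay_min if retry_delay_min is not None else 1.0)   # 0.0 -- preserved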


def get_effective_azure_login_app_id(hostname) -> str:
@@ -69,7 +99,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
return AzureAppId.PROD.value[1]


def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
"""
Load the Azure tenant ID from the Azure Databricks login page.

@@ -78,23 +108,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
the Azure login page, and the tenant ID is extracted from the redirect URL.
"""

if http_client is None:
http_client = DatabricksHttpClient.get_instance()

login_url = f"{host}/aad/auth"
logger.debug("Loading tenant ID from %s", login_url)
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
if resp.status_code // 100 != 3:

with http_client.request_context('GET', login_url, allow_redirects=False) as resp:
if resp.status // 100 != 3:
raise ValueError(
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
)
entra_id_endpoint = resp.headers.get("Location")
entra_id_endpoint = dict(resp.headers).get("Location")
if entra_id_endpoint is None:
raise ValueError(f"No Location header in response from {login_url}")
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]

# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]
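The rewritten function now expects the injected client to expose a request_context(...) context manager that yields a urllib3-style response (resp.status, resp.headers, resp.data) rather than a requests Response. A minimal sketch of such a wrapper over urllib3; the class name is an assumption and the connector's actual unified client is richer:

import contextlib
import urllib3

class MinimalHttpClient:
    """Sketch of the interface used above; not the connector's actual client."""

    def __init__(self):
        self._pool = urllib3.PoolManager()

    @contextlib.contextmanager
    def request_context(self, method, url, allow_redirects=True, **kwargs):
        # urllib3 follows redirects unless redirect=False is passed through.
        resp = self._pool.request(method, url, redirect=allow_redirects, **kwargs)
        try:
            yield resp   # exposes .status, .headers, .data
        finally:
            resp.release_conn()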
28 changes: 18 additions & 10 deletions src/databricks/sql/auth/oauth.py
@@ -9,10 +9,8 @@
from typing import List, Optional

import oauthlib.oauth2
import requests
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from requests.exceptions import RequestException
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
from databricks.sql.common.http import HttpMethod, HttpHeader
from databricks.sql.common.http import OAuthResponse
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
from databricks.sql.auth.endpoint import OAuthEndpointCollection
@@ -85,11 +83,13 @@ def __init__(
port_range: List[int],
client_id: str,
idp_endpoint: OAuthEndpointCollection,
http_client,
):
self.port_range = port_range
self.client_id = client_id
self.redirect_port = None
self.idp_endpoint = idp_endpoint
self.http_client = http_client

@staticmethod
def __token_urlsafe(nbytes=32):
@@ -103,8 +103,12 @@ def __fetch_well_known_config(self, hostname: str):
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)

try:
response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth())
except RequestException as e:
from databricks.sql.common.unified_http_client import IgnoreNetrcAuth
response = self.http_client.request('GET', url=known_config_url)
# Convert urllib3 response to requests-like response for compatibility
response.status_code = response.status
response.json = lambda: json.loads(response.data.decode())
except Exception as e:
logger.error(
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
"Verify it is a valid workspace URL and that OAuth is "
@@ -122,7 +126,7 @@ def __fetch_well_known_config(self, hostname: str):
raise RuntimeError(msg)
try:
return response.json()
except requests.exceptions.JSONDecodeError as e:
except Exception as e:
logger.error(
f"Unable to decode OAuth configuration from {known_config_url}.\n"
"Verify it is a valid workspace URL and that OAuth is "
@@ -209,10 +213,13 @@ def __send_token_request(token_request_url, data):
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
response = requests.post(
url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth()
# Use unified HTTP client
from databricks.sql.common.unified_http_client import IgnoreNetrcAuth
response = self.http_client.request(
Reviewer comment (Contributor): How is a static method using class http_client?
'POST', url=token_request_url, body=data, headers=headers
)
return response.json()
# Convert urllib3 response to dict for compatibility
return json.loads(response.data.decode())

def __send_refresh_token_request(self, hostname, refresh_token):
oauth_config = self.__fetch_well_known_config(hostname)
@@ -320,14 +327,15 @@ def __init__(
token_url,
client_id,
client_secret,
http_client,
extra_params: dict = {},
):
self.client_id = client_id
self.client_secret = client_secret
self.token_url = token_url
self.extra_params = extra_params
self.token: Optional[Token] = None
self._http_client = DatabricksHttpClient.get_instance()
self._http_client = http_client

def get_token(self) -> Token:
if self.token is None or self.token.is_expired():
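On the reviewer's question above: a staticmethod has no self, so self.http_client inside __send_token_request would fail at call time. One possible resolution, sketched here rather than taken from the PR, is to drop the decorator and make it an instance method (internal callers would then use self.__send_token_request(...)):

# Sketch of one way to address the static-method concern; everything except the
# request call shape shown in the diff is an assumption.
def __send_token_request(self, token_request_url, data):
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    response = self.http_client.request(
        "POST", url=token_request_url, body=data, headers=headers
    )
    # Convert the urllib3 response body to a dict, as in the diff
    return json.loads(response.data.decode())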
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/sea/queue.py
@@ -50,6 +50,7 @@ def build_queue(
max_download_threads: int,
sea_client: SeaDatabricksClient,
lz4_compressed: bool,
http_client,
) -> ResultSetQueue:
"""
Factory method to build a result set queue for SEA backend.
@@ -94,6 +95,7 @@ def build_queue(
total_chunk_count=manifest.total_chunk_count,
lz4_compressed=lz4_compressed,
description=description,
http_client=http_client,
)
raise ProgrammingError("Invalid result format")

@@ -309,6 +311,7 @@ def __init__(
sea_client: SeaDatabricksClient,
statement_id: str,
total_chunk_count: int,
http_client,
lz4_compressed: bool = False,
description: List[Tuple] = [],
):
@@ -337,6 +340,7 @@ def __init__(
# TODO: fix these arguments when telemetry is implemented in SEA
session_id_hex=None,
chunk_id=0,
http_client=http_client,
)

logger.debug(
1 change: 1 addition & 0 deletions src/databricks/sql/backend/sea/result_set.py
@@ -64,6 +64,7 @@ def __init__(
max_download_threads=sea_client.max_download_threads,
sea_client=sea_client,
lz4_compressed=execute_response.lz4_compressed,
http_client=connection.session.http_client,
)

# Call parent constructor with common attributes
38 changes: 27 additions & 11 deletions src/databricks/sql/client.py
@@ -6,7 +6,6 @@
import pyarrow
except ImportError:
pyarrow = None
import requests
import json
import os
import decimal
@@ -292,6 +291,7 @@ def read(self) -> Optional[OAuthToken]:
auth_provider=self.session.auth_provider,
host_url=self.session.host,
batch_size=self.telemetry_batch_size,
http_client=self.session.http_client,
Reviewer comment (Contributor): So we are not keeping a different client for Telemetry? (I believe there is different retry behaviour for it.)
)

self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
@@ -744,16 +744,20 @@ def _handle_staging_put(
)

with open(local_file, "rb") as fh:
r = requests.put(url=presigned_url, data=fh, headers=headers)
r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers)
Reviewer comment (Contributor): @sreekanth-db Can you discuss with @vikrantpuppala and integrate the new common client approach in volume operations?
# Add compatibility attributes for urllib3 response
r.status_code = r.status
if hasattr(r, 'data'):
r.content = r.data
r.ok = r.status < 400
r.text = r.data.decode() if r.data else ""

# fmt: off
# Design borrowed from: https://stackoverflow.com/a/2342589/5093960

OK = requests.codes.ok # 200
CREATED = requests.codes.created # 201
ACCEPTED = requests.codes.accepted # 202
NO_CONTENT = requests.codes.no_content # 204

# HTTP status codes
OK = 200
CREATED = 201
ACCEPTED = 202
NO_CONTENT = 204
# fmt: on

if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
@@ -783,7 +787,13 @@ def _handle_staging_get(
session_id_hex=self.connection.get_session_id_hex(),
)

r = requests.get(url=presigned_url, headers=headers)
r = self.connection.session.http_client.request('GET', presigned_url, headers=headers)
# Add compatibility attributes for urllib3 response
r.status_code = r.status
if hasattr(r, 'data'):
r.content = r.data
r.ok = r.status < 400
r.text = r.data.decode() if r.data else ""

# response.ok verifies the status code is not between 400-600.
# Any 2xx or 3xx will evaluate r.ok == True
@@ -802,7 +812,13 @@ def _handle_staging_remove(
):
"""Make an HTTP DELETE request to the presigned_url"""

r = requests.delete(url=presigned_url, headers=headers)
r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers)
# Add compatibility attributes for urllib3 response
r.status_code = r.status
if hasattr(r, 'data'):
r.content = r.data
r.ok = r.status < 400
r.text = r.data.decode() if r.data else ""

if not r.ok:
raise OperationalError(
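The same four requests-style attributes (status_code, content, ok, text) are patched onto the urllib3 response in _handle_staging_put, _handle_staging_get, and _handle_staging_remove. A small helper could keep that shim in one place; this is a sketch only, and the helper name is made up:

def _as_requests_like(r):
    """Attach requests-style attributes to a urllib3 response (sketch only)."""
    r.status_code = r.status
    if hasattr(r, "data"):
        r.content = r.data
    r.ok = r.status < 400
    r.text = r.data.decode() if r.data else ""
    return r

# e.g. in _handle_staging_get:
#   r = _as_requests_like(
#       self.connection.session.http_client.request("GET", presigned_url, headers=headers)
#   )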
3 changes: 3 additions & 0 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -25,6 +25,7 @@ def __init__(
session_id_hex: Optional[str],
statement_id: str,
chunk_id: int,
http_client,
):
self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
self.chunk_id = chunk_id
@@ -47,6 +48,7 @@ def __init__(
self._ssl_options = ssl_options
self.session_id_hex = session_id_hex
self.statement_id = statement_id
self._http_client = http_client

def get_next_downloaded_file(
self, next_row_offset: int
@@ -109,6 +111,7 @@ def _schedule_downloads(self):
chunk_id=chunk_id,
session_id_hex=self.session_id_hex,
statement_id=self.statement_id,
http_client=self._http_client,
)
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)