Skip to content

Commit 4437a2a

Browse files
Refactor codebase to use a unified http client
Signed-off-by: Vikrant Puppala <[email protected]>
1 parent fd81c5a commit 4437a2a

File tree

16 files changed

+440
-211
lines changed

16 files changed

+440
-211
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
3535
cfg.oauth_client_id,
3636
cfg.oauth_scopes,
3737
cfg.auth_type,
38+
http_client=http_client,
3839
)
3940
elif cfg.access_token is not None:
4041
return AccessTokenAuthProvider(cfg.access_token)
@@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext):
5354
cfg.oauth_redirect_port_range,
5455
cfg.oauth_client_id,
5556
cfg.oauth_scopes,
57+
http_client=http_client,
5658
)
5759
else:
5860
raise RuntimeError("No valid authentication settings!")
@@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
7981
)
8082

8183

82-
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
84+
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
8385
# TODO : unify all the auth mechanisms with the Python SDK
8486

8587
auth_type = kwargs.get("auth_type")

src/databricks/sql/auth/authenticators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
redirect_port_range: List[int],
6464
client_id: str,
6565
scopes: List[str],
66+
http_client,
6667
auth_type: str = "databricks-oauth",
6768
):
6869
try:
@@ -79,6 +80,7 @@ def __init__(
7980
port_range=redirect_port_range,
8081
client_id=client_id,
8182
idp_endpoint=idp_endpoint,
83+
http_client=http_client,
8284
)
8385
self._hostname = hostname
8486
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)

src/databricks/sql/auth/common.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
from typing import Optional, List
44
from urllib.parse import urlparse
5-
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
65

76
logger = logging.getLogger(__name__)
87

@@ -36,6 +35,21 @@ def __init__(
3635
tls_client_cert_file: Optional[str] = None,
3736
oauth_persistence=None,
3837
credentials_provider=None,
38+
# HTTP client configuration parameters
39+
ssl_options=None, # SSLOptions type
40+
socket_timeout: Optional[float] = None,
41+
retry_stop_after_attempts_count: Optional[int] = None,
42+
retry_delay_min: Optional[float] = None,
43+
retry_delay_max: Optional[float] = None,
44+
retry_stop_after_attempts_duration: Optional[float] = None,
45+
retry_delay_default: Optional[float] = None,
46+
retry_dangerous_codes: Optional[List[int]] = None,
47+
http_proxy: Optional[str] = None,
48+
proxy_username: Optional[str] = None,
49+
proxy_password: Optional[str] = None,
50+
pool_connections: Optional[int] = None,
51+
pool_maxsize: Optional[int] = None,
52+
user_agent: Optional[str] = None,
3953
):
4054
self.hostname = hostname
4155
self.access_token = access_token
@@ -51,6 +65,22 @@ def __init__(
5165
self.tls_client_cert_file = tls_client_cert_file
5266
self.oauth_persistence = oauth_persistence
5367
self.credentials_provider = credentials_provider
68+
69+
# HTTP client configuration
70+
self.ssl_options = ssl_options
71+
self.socket_timeout = socket_timeout
72+
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
73+
self.retry_delay_min = retry_delay_min or 1.0
74+
self.retry_delay_max = retry_delay_max or 60.0
75+
self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
76+
self.retry_delay_default = retry_delay_default or 5.0
77+
self.retry_dangerous_codes = retry_dangerous_codes or []
78+
self.http_proxy = http_proxy
79+
self.proxy_username = proxy_username
80+
self.proxy_password = proxy_password
81+
self.pool_connections = pool_connections or 10
82+
self.pool_maxsize = pool_maxsize or 20
83+
self.user_agent = user_agent
5484

5585

5686
def get_effective_azure_login_app_id(hostname) -> str:
@@ -69,7 +99,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
6999
return AzureAppId.PROD.value[1]
70100

71101

72-
def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
102+
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
73103
"""
74104
Load the Azure tenant ID from the Azure Databricks login page.
75105
@@ -78,23 +108,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
78108
the Azure login page, and the tenant ID is extracted from the redirect URL.
79109
"""
80110

81-
if http_client is None:
82-
http_client = DatabricksHttpClient.get_instance()
83-
84111
login_url = f"{host}/aad/auth"
85112
logger.debug("Loading tenant ID from %s", login_url)
86-
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
87-
if resp.status_code // 100 != 3:
113+
114+
with http_client.request_context('GET', login_url, allow_redirects=False) as resp:
115+
if resp.status // 100 != 3:
88116
raise ValueError(
89-
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
117+
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
90118
)
91-
entra_id_endpoint = resp.headers.get("Location")
119+
entra_id_endpoint = dict(resp.headers).get("Location")
92120
if entra_id_endpoint is None:
93121
raise ValueError(f"No Location header in response from {login_url}")
94-
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
95-
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
96-
url = urlparse(entra_id_endpoint)
97-
path_segments = url.path.split("/")
98-
if len(path_segments) < 2:
99-
raise ValueError(f"Invalid path in Location header: {url.path}")
100-
return path_segments[1]
122+
123+
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
124+
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
125+
url = urlparse(entra_id_endpoint)
126+
path_segments = url.path.split("/")
127+
if len(path_segments) < 2:
128+
raise ValueError(f"Invalid path in Location header: {url.path}")
129+
return path_segments[1]

src/databricks/sql/auth/oauth.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
from typing import List, Optional
1010

1111
import oauthlib.oauth2
12-
import requests
1312
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
14-
from requests.exceptions import RequestException
15-
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
13+
from databricks.sql.common.http import HttpMethod, HttpHeader
1614
from databricks.sql.common.http import OAuthResponse
1715
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
1816
from databricks.sql.auth.endpoint import OAuthEndpointCollection
@@ -85,11 +83,13 @@ def __init__(
8583
port_range: List[int],
8684
client_id: str,
8785
idp_endpoint: OAuthEndpointCollection,
86+
http_client,
8887
):
8988
self.port_range = port_range
9089
self.client_id = client_id
9190
self.redirect_port = None
9291
self.idp_endpoint = idp_endpoint
92+
self.http_client = http_client
9393

9494
@staticmethod
9595
def __token_urlsafe(nbytes=32):
@@ -103,8 +103,12 @@ def __fetch_well_known_config(self, hostname: str):
103103
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
104104

105105
try:
106-
response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth())
107-
except RequestException as e:
106+
from databricks.sql.common.unified_http_client import IgnoreNetrcAuth
107+
response = self.http_client.request('GET', url=known_config_url)
108+
# Convert urllib3 response to requests-like response for compatibility
109+
response.status_code = response.status
110+
response.json = lambda: json.loads(response.data.decode())
111+
except Exception as e:
108112
logger.error(
109113
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
110114
"Verify it is a valid workspace URL and that OAuth is "
@@ -122,7 +126,7 @@ def __fetch_well_known_config(self, hostname: str):
122126
raise RuntimeError(msg)
123127
try:
124128
return response.json()
125-
except requests.exceptions.JSONDecodeError as e:
129+
except Exception as e:
126130
logger.error(
127131
f"Unable to decode OAuth configuration from {known_config_url}.\n"
128132
"Verify it is a valid workspace URL and that OAuth is "
@@ -209,10 +213,13 @@ def __send_token_request(token_request_url, data):
209213
"Accept": "application/json",
210214
"Content-Type": "application/x-www-form-urlencoded",
211215
}
212-
response = requests.post(
213-
url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth()
216+
# Use unified HTTP client
217+
from databricks.sql.common.unified_http_client import IgnoreNetrcAuth
218+
response = self.http_client.request(
219+
'POST', url=token_request_url, body=data, headers=headers
214220
)
215-
return response.json()
221+
# Convert urllib3 response to dict for compatibility
222+
return json.loads(response.data.decode())
216223

217224
def __send_refresh_token_request(self, hostname, refresh_token):
218225
oauth_config = self.__fetch_well_known_config(hostname)
@@ -320,14 +327,15 @@ def __init__(
320327
token_url,
321328
client_id,
322329
client_secret,
330+
http_client,
323331
extra_params: dict = {},
324332
):
325333
self.client_id = client_id
326334
self.client_secret = client_secret
327335
self.token_url = token_url
328336
self.extra_params = extra_params
329337
self.token: Optional[Token] = None
330-
self._http_client = DatabricksHttpClient.get_instance()
338+
self._http_client = http_client
331339

332340
def get_token(self) -> Token:
333341
if self.token is None or self.token.is_expired():

src/databricks/sql/backend/sea/queue.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def build_queue(
5050
max_download_threads: int,
5151
sea_client: SeaDatabricksClient,
5252
lz4_compressed: bool,
53+
http_client,
5354
) -> ResultSetQueue:
5455
"""
5556
Factory method to build a result set queue for SEA backend.
@@ -94,6 +95,7 @@ def build_queue(
9495
total_chunk_count=manifest.total_chunk_count,
9596
lz4_compressed=lz4_compressed,
9697
description=description,
98+
http_client=http_client,
9799
)
98100
raise ProgrammingError("Invalid result format")
99101

@@ -309,6 +311,7 @@ def __init__(
309311
sea_client: SeaDatabricksClient,
310312
statement_id: str,
311313
total_chunk_count: int,
314+
http_client,
312315
lz4_compressed: bool = False,
313316
description: List[Tuple] = [],
314317
):
@@ -337,6 +340,7 @@ def __init__(
337340
# TODO: fix these arguments when telemetry is implemented in SEA
338341
session_id_hex=None,
339342
chunk_id=0,
343+
http_client=http_client,
340344
)
341345

342346
logger.debug(

src/databricks/sql/backend/sea/result_set.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
max_download_threads=sea_client.max_download_threads,
6565
sea_client=sea_client,
6666
lz4_compressed=execute_response.lz4_compressed,
67+
http_client=connection.session.http_client,
6768
)
6869

6970
# Call parent constructor with common attributes

src/databricks/sql/client.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pyarrow
77
except ImportError:
88
pyarrow = None
9-
import requests
109
import json
1110
import os
1211
import decimal
@@ -292,6 +291,7 @@ def read(self) -> Optional[OAuthToken]:
292291
auth_provider=self.session.auth_provider,
293292
host_url=self.session.host,
294293
batch_size=self.telemetry_batch_size,
294+
http_client=self.session.http_client,
295295
)
296296

297297
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
@@ -744,16 +744,20 @@ def _handle_staging_put(
744744
)
745745

746746
with open(local_file, "rb") as fh:
747-
r = requests.put(url=presigned_url, data=fh, headers=headers)
747+
r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers)
748+
# Add compatibility attributes for urllib3 response
749+
r.status_code = r.status
750+
if hasattr(r, 'data'):
751+
r.content = r.data
752+
r.ok = r.status < 400
753+
r.text = r.data.decode() if r.data else ""
748754

749755
# fmt: off
750-
# Design borrowed from: https://stackoverflow.com/a/2342589/5093960
751-
752-
OK = requests.codes.ok # 200
753-
CREATED = requests.codes.created # 201
754-
ACCEPTED = requests.codes.accepted # 202
755-
NO_CONTENT = requests.codes.no_content # 204
756-
756+
# HTTP status codes
757+
OK = 200
758+
CREATED = 201
759+
ACCEPTED = 202
760+
NO_CONTENT = 204
757761
# fmt: on
758762

759763
if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
@@ -783,7 +787,13 @@ def _handle_staging_get(
783787
session_id_hex=self.connection.get_session_id_hex(),
784788
)
785789

786-
r = requests.get(url=presigned_url, headers=headers)
790+
r = self.connection.session.http_client.request('GET', presigned_url, headers=headers)
791+
# Add compatibility attributes for urllib3 response
792+
r.status_code = r.status
793+
if hasattr(r, 'data'):
794+
r.content = r.data
795+
r.ok = r.status < 400
796+
r.text = r.data.decode() if r.data else ""
787797

788798
# response.ok verifies the status code is not between 400-600.
789799
# Any 2xx or 3xx will evaluate r.ok == True
@@ -802,7 +812,13 @@ def _handle_staging_remove(
802812
):
803813
"""Make an HTTP DELETE request to the presigned_url"""
804814

805-
r = requests.delete(url=presigned_url, headers=headers)
815+
r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers)
816+
# Add compatibility attributes for urllib3 response
817+
r.status_code = r.status
818+
if hasattr(r, 'data'):
819+
r.content = r.data
820+
r.ok = r.status < 400
821+
r.text = r.data.decode() if r.data else ""
806822

807823
if not r.ok:
808824
raise OperationalError(

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
session_id_hex: Optional[str],
2626
statement_id: str,
2727
chunk_id: int,
28+
http_client,
2829
):
2930
self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
3031
self.chunk_id = chunk_id
@@ -47,6 +48,7 @@ def __init__(
4748
self._ssl_options = ssl_options
4849
self.session_id_hex = session_id_hex
4950
self.statement_id = statement_id
51+
self._http_client = http_client
5052

5153
def get_next_downloaded_file(
5254
self, next_row_offset: int
@@ -109,6 +111,7 @@ def _schedule_downloads(self):
109111
chunk_id=chunk_id,
110112
session_id_hex=self.session_id_hex,
111113
statement_id=self.statement_id,
114+
http_client=self._http_client,
112115
)
113116
task = self._thread_pool.submit(handler.run)
114117
self._download_tasks.append(task)

0 commit comments

Comments
 (0)