Skip to content

Commit 946a265

Browse files
created util method to normalise http protocol in http path (#724)
* created util method to normalise http protocol in http path Signed-off-by: Nikhil Suri <[email protected]> * Added impacted files using util method Signed-off-by: Nikhil Suri <[email protected]> * Fixed linting issues Signed-off-by: Nikhil Suri <[email protected]> * fixed broken test with mock host string Signed-off-by: Nikhil Suri <[email protected]> * mocked http client Signed-off-by: Nikhil Suri <[email protected]> * made case sensitive check in url utils Signed-off-by: Nikhil Suri <[email protected]> * linting issue resolved Signed-off-by: Nikhil Suri <[email protected]> * removed unnecessary md files Signed-off-by: Nikhil Suri <[email protected]> * made test readbale Signed-off-by: Nikhil Suri <[email protected]> * changes done in auth util as well as sea http Signed-off-by: Nikhil Suri <[email protected]> --------- Signed-off-by: Nikhil Suri <[email protected]>
1 parent 9b4e577 commit 946a265

File tree

11 files changed

+234
-61
lines changed

11 files changed

+234
-61
lines changed

src/databricks/sql/auth/auth_utils.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,6 @@
77
logger = logging.getLogger(__name__)
88

99

10-
def parse_hostname(hostname: str) -> str:
11-
"""
12-
Normalize the hostname to include scheme and trailing slash.
13-
14-
Args:
15-
hostname: The hostname to normalize
16-
17-
Returns:
18-
Normalized hostname with scheme and trailing slash
19-
"""
20-
if not hostname.startswith("http://") and not hostname.startswith("https://"):
21-
hostname = f"https://{hostname}"
22-
if not hostname.endswith("/"):
23-
hostname = f"{hostname}/"
24-
return hostname
25-
26-
2710
def decode_token(access_token: str) -> Optional[Dict]:
2811
"""
2912
Decode a JWT token without verification to extract claims.

src/databricks/sql/auth/token_federation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
from databricks.sql.auth.authenticators import AuthProvider
88
from databricks.sql.auth.auth_utils import (
9-
parse_hostname,
109
decode_token,
1110
is_same_host,
1211
)
12+
from databricks.sql.common.url_utils import normalize_host_with_protocol
1313
from databricks.sql.common.http import HttpMethod
1414

1515
logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def __init__(
9999
if not http_client:
100100
raise ValueError("http_client is required for TokenFederationProvider")
101101

102-
self.hostname = parse_hostname(hostname)
102+
self.hostname = normalize_host_with_protocol(hostname)
103103
self.external_provider = external_provider
104104
self.http_client = http_client
105105
self.identity_federation_client_id = identity_federation_client_id
@@ -164,7 +164,7 @@ def _should_exchange_token(self, access_token: str) -> bool:
164164

165165
def _exchange_token(self, access_token: str) -> Token:
166166
"""Exchange the external token for a Databricks token."""
167-
token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"
167+
token_url = f"{self.hostname}{self.TOKEN_EXCHANGE_ENDPOINT}"
168168

169169
data = {
170170
"grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE,

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from databricks.sql.common.http_utils import (
1919
detect_and_parse_proxy,
2020
)
21+
from databricks.sql.common.url_utils import normalize_host_with_protocol
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -66,8 +67,9 @@ def __init__(
6667
self.auth_provider = auth_provider
6768
self.ssl_options = ssl_options
6869

69-
# Build base URL
70-
self.base_url = f"https://{server_hostname}:{self.port}"
70+
# Build base URL using url_utils for consistent normalization
71+
normalized_host = normalize_host_with_protocol(server_hostname)
72+
self.base_url = f"{normalized_host}:{self.port}"
7173

7274
# Parse URL for proxy handling
7375
parsed_url = urllib.parse.urlparse(self.base_url)

src/databricks/sql/common/feature_flag.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, Optional, List, Any, TYPE_CHECKING
77

88
from databricks.sql.common.http import HttpMethod
9+
from databricks.sql.common.url_utils import normalize_host_with_protocol
910

1011
if TYPE_CHECKING:
1112
from databricks.sql.client import Connection
@@ -67,7 +68,8 @@ def __init__(
6768

6869
endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
6970
self._feature_flag_endpoint = (
70-
f"https://{self._connection.session.host}{endpoint_suffix}"
71+
normalize_host_with_protocol(self._connection.session.host)
72+
+ endpoint_suffix
7173
)
7274

7375
# Use the provided HTTP client
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
URL utility functions for the Databricks SQL connector.
3+
"""
4+
5+
6+
def normalize_host_with_protocol(host: str) -> str:
7+
"""
8+
Normalize a connection hostname by ensuring it has a protocol.
9+
10+
This is useful for handling cases where users may provide hostnames with or without protocols
11+
(common with dbt-databricks users copying URLs from their browser).
12+
13+
Args:
14+
host: Connection hostname which may or may not include a protocol prefix (https:// or http://)
15+
and may or may not have a trailing slash
16+
17+
Returns:
18+
Normalized hostname with protocol prefix and no trailing slashes
19+
20+
Examples:
21+
normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
22+
normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
23+
normalize_host_with_protocol("HTTPS://myserver.com/") -> "https://myserver.com"
24+
normalize_host_with_protocol("http://localhost:8080/") -> "http://localhost:8080"
25+
26+
Raises:
27+
ValueError: If host is None or empty string
28+
"""
29+
# Handle None or empty host
30+
if not host or not host.strip():
31+
raise ValueError("Host cannot be None or empty")
32+
33+
# Remove trailing slashes
34+
host = host.rstrip("/")
35+
36+
# Add protocol if not present (case-insensitive check)
37+
host_lower = host.lower()
38+
if not host_lower.startswith("https://") and not host_lower.startswith("http://"):
39+
host = f"https://{host}"
40+
elif host_lower.startswith("https://") or host_lower.startswith("http://"):
41+
# Normalize protocol to lowercase
42+
protocol_end = host.index("://") + 3
43+
host = host[:protocol_end].lower() + host[protocol_end:]
44+
45+
return host

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
TelemetryPushClient,
4848
CircuitBreakerTelemetryPushClient,
4949
)
50+
from databricks.sql.common.url_utils import normalize_host_with_protocol
5051

5152
if TYPE_CHECKING:
5253
from databricks.sql.client import Connection
@@ -278,7 +279,7 @@ def _send_telemetry(self, events):
278279
if self._auth_provider
279280
else self.TELEMETRY_UNAUTHENTICATED_PATH
280281
)
281-
url = f"https://{self._host_url}{path}"
282+
url = normalize_host_with_protocol(self._host_url) + path
282283

283284
headers = {"Accept": "application/json", "Content-Type": "application/json"}
284285

tests/e2e/test_circuit_breaker.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,34 @@
2323
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager
2424

2525

26+
def wait_for_circuit_state(circuit_breaker, expected_states, timeout=5):
27+
"""
28+
Wait for circuit breaker to reach one of the expected states with polling.
29+
30+
Args:
31+
circuit_breaker: The circuit breaker instance to monitor
32+
expected_states: List of acceptable states
33+
(STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN)
34+
timeout: Maximum time to wait in seconds
35+
36+
Returns:
37+
True if state reached, False if timeout
38+
39+
Examples:
40+
# Single state - pass list with one element
41+
wait_for_circuit_state(cb, [STATE_OPEN])
42+
43+
# Multiple states
44+
wait_for_circuit_state(cb, [STATE_CLOSED, STATE_HALF_OPEN])
45+
"""
46+
start = time.time()
47+
while time.time() - start < timeout:
48+
if circuit_breaker.current_state in expected_states:
49+
return True
50+
time.sleep(0.1) # Poll every 100ms
51+
return False
52+
53+
2654
@pytest.fixture(autouse=True)
2755
def aggressive_circuit_breaker_config():
2856
"""
@@ -65,12 +93,17 @@ def create_mock_response(self, status_code):
6593
}.get(status_code, b"Response")
6694
return response
6795

68-
@pytest.mark.parametrize("status_code,should_trigger", [
69-
(429, True),
70-
(503, True),
71-
(500, False),
72-
])
73-
def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger):
96+
@pytest.mark.parametrize(
97+
"status_code,should_trigger",
98+
[
99+
(429, True),
100+
(503, True),
101+
(500, False),
102+
],
103+
)
104+
def test_circuit_breaker_triggers_for_rate_limit_codes(
105+
self, status_code, should_trigger
106+
):
74107
"""
75108
Verify circuit breaker opens for rate-limit codes (429/503) but not others (500).
76109
"""
@@ -107,9 +140,14 @@ def mock_request(*args, **kwargs):
107140
time.sleep(0.5)
108141

109142
if should_trigger:
110-
# Circuit should be OPEN after 2 rate-limit failures
143+
# Wait for circuit to open (async telemetry may take time)
144+
assert wait_for_circuit_state(
145+
circuit_breaker, [STATE_OPEN], timeout=5
146+
), f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}"
147+
148+
# Circuit should be OPEN after rate-limit failures
111149
assert circuit_breaker.current_state == STATE_OPEN
112-
assert circuit_breaker.fail_counter == 2
150+
assert circuit_breaker.fail_counter >= 2 # At least 2 failures
113151

114152
# Track requests before another query
115153
requests_before = request_count["count"]
@@ -197,7 +235,10 @@ def mock_conditional_request(*args, **kwargs):
197235
cursor.fetchone()
198236
time.sleep(2)
199237

200-
assert circuit_breaker.current_state == STATE_OPEN
238+
# Wait for circuit to open
239+
assert wait_for_circuit_state(
240+
circuit_breaker, [STATE_OPEN], timeout=5
241+
), f"Circuit didn't open, state: {circuit_breaker.current_state}"
201242

202243
# Wait for reset timeout (5 seconds in test)
203244
time.sleep(6)
@@ -208,24 +249,20 @@ def mock_conditional_request(*args, **kwargs):
208249
# Execute query to trigger HALF_OPEN state
209250
cursor.execute("SELECT 3")
210251
cursor.fetchone()
211-
time.sleep(1)
212252

213-
# Circuit should be recovering
214-
assert circuit_breaker.current_state in [
215-
STATE_HALF_OPEN,
216-
STATE_CLOSED,
217-
], f"Circuit should be recovering, but is {circuit_breaker.current_state}"
253+
# Wait for circuit to start recovering
254+
assert wait_for_circuit_state(
255+
circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5
256+
), f"Circuit didn't recover, state: {circuit_breaker.current_state}"
218257

219258
# Execute more queries to fully recover
220259
cursor.execute("SELECT 4")
221260
cursor.fetchone()
222-
time.sleep(1)
223261

224-
current_state = circuit_breaker.current_state
225-
assert current_state in [
226-
STATE_CLOSED,
227-
STATE_HALF_OPEN,
228-
], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}"
262+
# Wait for full recovery
263+
assert wait_for_circuit_state(
264+
circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5
265+
), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}"
229266

230267

231268
if __name__ == "__main__":

tests/unit/test_client.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase):
646646
"access_token": "tok",
647647
}
648648

649+
def _setup_mock_session_with_http_client(self, mock_session):
650+
"""
651+
Helper to configure a mock session with HTTP client mocks.
652+
This prevents feature flag network requests during Connection initialization.
653+
"""
654+
mock_session.host = "foo"
655+
656+
# Mock HTTP client to prevent feature flag network requests
657+
mock_http_client = Mock()
658+
mock_session.http_client = mock_http_client
659+
660+
# Mock feature flag response to prevent blocking HTTP calls
661+
mock_ff_response = Mock()
662+
mock_ff_response.status = 200
663+
mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}'
664+
mock_http_client.request.return_value = mock_ff_response
665+
649666
def _create_mock_connection(self, mock_session_class):
650667
"""Helper to create a mocked connection for transaction tests."""
651-
# Mock session
652668
mock_session = Mock()
653669
mock_session.is_open = True
654670
mock_session.guid_hex = "test-session-id"
655671
mock_session.get_autocommit.return_value = True
672+
673+
self._setup_mock_session_with_http_client(mock_session)
656674
mock_session_class.return_value = mock_session
657675

658676
# Create connection with ignore_transactions=False to test actual transaction functionality
@@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class):
736754
conn = self._create_mock_connection(mock_session_class)
737755

738756
mock_cursor = Mock()
739-
original_error = DatabaseError(
740-
"Original error", host_url="test-host"
741-
)
757+
original_error = DatabaseError("Original error", host_url="test-host")
742758
mock_cursor.execute.side_effect = original_error
743759

744760
with patch.object(conn, "cursor", return_value=mock_cursor):
@@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class):
927943
mock_session = Mock()
928944
mock_session.is_open = True
929945
mock_session.guid_hex = "test-session-id"
946+
947+
self._setup_mock_session_with_http_client(mock_session)
930948
mock_session_class.return_value = mock_session
931949

932950
conn = client.Connection(
@@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla
959977
mock_session = Mock()
960978
mock_session.is_open = True
961979
mock_session.guid_hex = "test-session-id"
980+
981+
self._setup_mock_session_with_http_client(mock_session)
962982
mock_session_class.return_value = mock_session
963983

964984
conn = client.Connection(
@@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla
9861006
mock_session = Mock()
9871007
mock_session.is_open = True
9881008
mock_session.guid_hex = "test-session-id"
1009+
1010+
self._setup_mock_session_with_http_client(mock_session)
9891011
mock_session_class.return_value = mock_session
9901012

9911013
conn = client.Connection(
@@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class):
10151037
mock_session = Mock()
10161038
mock_session.is_open = True
10171039
mock_session.guid_hex = "test-session-id"
1040+
1041+
self._setup_mock_session_with_http_client(mock_session)
10181042
mock_session_class.return_value = mock_session
10191043

10201044
# Create connection with ignore_transactions=True (default)
@@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true(
10431067
mock_session = Mock()
10441068
mock_session.is_open = True
10451069
mock_session.guid_hex = "test-session-id"
1070+
1071+
self._setup_mock_session_with_http_client(mock_session)
10461072
mock_session_class.return_value = mock_session
10471073

10481074
# Create connection with ignore_transactions=True (default)
@@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true(
10681094
mock_session = Mock()
10691095
mock_session.is_open = True
10701096
mock_session.guid_hex = "test-session-id"
1097+
1098+
self._setup_mock_session_with_http_client(mock_session)
10711099
mock_session_class.return_value = mock_session
10721100

10731101
# Create connection with ignore_transactions=True (default)

0 commit comments

Comments
 (0)