Skip to content
4 changes: 3 additions & 1 deletion src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Optional, List, Any, TYPE_CHECKING

from databricks.sql.common.http import HttpMethod
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
normalize_host_with_protocol(self._connection.session.host)
+ endpoint_suffix
)

# Use the provided HTTP client
Expand Down
31 changes: 31 additions & 0 deletions src/databricks/sql/common/url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
URL utility functions for the Databricks SQL connector.
"""


def normalize_host_with_protocol(host: str) -> str:
"""
Normalize a connection hostname by ensuring it has a protocol and removing trailing slashes.

This is useful for handling cases where users may provide hostnames with or without protocols
(common with dbt-databricks users copying URLs from their browser).

Args:
host: Connection hostname which may or may not include a protocol prefix (https:// or http://)
and may or may not have a trailing slash

Returns:
Normalized hostname with protocol prefix and no trailing slash

Examples:
normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
"""
# Remove trailing slash
host = host.rstrip("/")

# Add protocol if not present
if not host.startswith("https://") and not host.startswith("http://"):
host = f"https://{host}"

return host
3 changes: 2 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TelemetryPushClient,
CircuitBreakerTelemetryPushClient,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -278,7 +279,7 @@ def _send_telemetry(self, events):
if self._auth_provider
else self.TELEMETRY_UNAUTHENTICATED_PATH
)
url = f"https://{self._host_url}{path}"
url = normalize_host_with_protocol(self._host_url) + path

headers = {"Accept": "application/json", "Content-Type": "application/json"}

Expand Down
36 changes: 32 additions & 4 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase):
"access_token": "tok",
}

def _setup_mock_session_with_http_client(self, mock_session):
"""
Helper to configure a mock session with HTTP client mocks.
This prevents feature flag network requests during Connection initialization.
"""
mock_session.host = "foo"

# Mock HTTP client to prevent feature flag network requests
mock_http_client = Mock()
mock_session.http_client = mock_http_client

# Mock feature flag response to prevent blocking HTTP calls
mock_ff_response = Mock()
mock_ff_response.status = 200
mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}'
mock_http_client.request.return_value = mock_ff_response

def _create_mock_connection(self, mock_session_class):
"""Helper to create a mocked connection for transaction tests."""
# Mock session
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"
mock_session.get_autocommit.return_value = True

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=False to test actual transaction functionality
Expand Down Expand Up @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class):
conn = self._create_mock_connection(mock_session_class)

mock_cursor = Mock()
original_error = DatabaseError(
"Original error", host_url="test-host"
)
original_error = DatabaseError("Original error", host_url="test-host")
mock_cursor.execute.side_effect = original_error

with patch.object(conn, "cursor", return_value=mock_cursor):
Expand Down Expand Up @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down Expand Up @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand All @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Tests for URL utility functions."""
import pytest
from databricks.sql.common.url_utils import normalize_host_with_protocol


class TestNormalizeHostWithProtocol:
"""Tests for normalize_host_with_protocol function."""

@pytest.mark.parametrize("input_host,expected_output", [
# Hostname without protocol - should add https://
("myserver.com", "https://myserver.com"),
("workspace.databricks.com", "https://workspace.databricks.com"),

# Hostname with https:// - should not duplicate
("https://myserver.com", "https://myserver.com"),
("https://workspace.databricks.com", "https://workspace.databricks.com"),

# Hostname with http:// - should preserve
("http://localhost", "http://localhost"),
("http://myserver.com:8080", "http://myserver.com:8080"),

# Hostname with port numbers
("myserver.com:443", "https://myserver.com:443"),
("https://myserver.com:443", "https://myserver.com:443"),
("http://localhost:8080", "http://localhost:8080"),

# Trailing slash - should be removed
("myserver.com/", "https://myserver.com"),
("https://myserver.com/", "https://myserver.com"),
("http://localhost/", "http://localhost"),

])
def test_normalize_host_with_protocol(self, input_host, expected_output):
"""Test host normalization with various input formats."""
assert normalize_host_with_protocol(input_host) == expected_output

Loading