Skip to content

Commit a6ff00f

Browse files
Fix sql_warehouse_name resolution: handle 'warehouses' API response key (#63286)
* Fix sql_warehouse_name resolution failing with "Can't list Databricks SQL endpoints" The _get_sql_endpoint_by_name method calls GET /api/2.0/sql/warehouses (the current API path) but checks for the "endpoints" key in the response. Since Databricks renamed SQL endpoints to SQL warehouses, the current API returns data under the "warehouses" key, causing the check to always fail. This fix handles both the current ("warehouses") and legacy ("endpoints") response keys for backward compatibility. Closes: #63285 * Use standard Python exceptions instead of AirflowException Replace AirflowException with standard Python exceptions per contributing guidelines: - RuntimeError for unexpected API response (no warehouses/endpoints key) - ValueError for warehouse name not found in results
1 parent 3996a4a commit a6ff00f

File tree

2 files changed

+94
-6
lines changed

2 files changed

+94
-6
lines changed

providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,20 @@ def _get_extra_config(self) -> dict[str, Any | None]:
129129

130130
def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
131131
result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
132-
if "endpoints" not in result:
133-
raise AirflowException("Can't list Databricks SQL endpoints")
132+
# The API response key depends on which endpoint path is used:
133+
# - "warehouses" for the current /api/2.0/sql/warehouses path
134+
# - "endpoints" for the legacy /api/2.0/sql/endpoints path
135+
warehouses = result.get("warehouses") or result.get("endpoints")
136+
if not warehouses:
137+
raise RuntimeError(
138+
"Can't list Databricks SQL warehouses. The API response contained neither "
139+
"'warehouses' nor 'endpoints' key. Check that the connection has sufficient "
140+
"permissions to list SQL warehouses."
141+
)
134142
try:
135-
endpoint = next(endpoint for endpoint in result["endpoints"] if endpoint["name"] == endpoint_name)
143+
endpoint = next(ep for ep in warehouses if ep["name"] == endpoint_name)
136144
except StopIteration:
137-
raise AirflowException(f"Can't find Databricks SQL endpoint with name '{endpoint_name}'")
145+
raise ValueError(f"Can't find Databricks SQL warehouse with name '{endpoint_name}'")
138146
else:
139147
return endpoint
140148

providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def mock_get_requests():
8080
mock_patch = patch("airflow.providers.databricks.hooks.databricks_base.requests")
8181
mock_requests = mock_patch.start()
8282

83-
# Configure the mock object
83+
# Configure the mock object with the current API response format ("warehouses" key)
8484
mock_requests.codes.ok = 200
8585
mock_requests.get.return_value.json.return_value = {
86-
"endpoints": [
86+
"warehouses": [
8787
{
8888
"id": "1264e5078741679a",
8989
"name": "Test",
@@ -712,3 +712,83 @@ def test_get_df(df_type, df_class, description):
712712
assert df.row(1)[0] == result_sets[1][0]
713713

714714
assert isinstance(df, df_class)
715+
716+
717+
class TestGetSqlEndpointByName:
718+
"""Tests for _get_sql_endpoint_by_name with both 'warehouses' and legacy 'endpoints' API response keys."""
719+
720+
@patch("airflow.providers.databricks.hooks.databricks_base.requests")
721+
def test_resolve_warehouse_name_with_warehouses_key(self, mock_requests):
722+
"""Test that the current API response format with 'warehouses' key works."""
723+
mock_requests.codes.ok = 200
724+
mock_requests.get.return_value.json.return_value = {
725+
"warehouses": [
726+
{
727+
"id": "abc123",
728+
"name": "My Warehouse",
729+
"odbc_params": {
730+
"hostname": "xx.cloud.databricks.com",
731+
"path": "/sql/1.0/warehouses/abc123",
732+
},
733+
}
734+
]
735+
}
736+
type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
737+
738+
hook = DatabricksSqlHook(sql_endpoint_name="My Warehouse")
739+
endpoint = hook._get_sql_endpoint_by_name("My Warehouse")
740+
assert endpoint["id"] == "abc123"
741+
assert endpoint["odbc_params"]["path"] == "/sql/1.0/warehouses/abc123"
742+
743+
@patch("airflow.providers.databricks.hooks.databricks_base.requests")
744+
def test_resolve_warehouse_name_with_legacy_endpoints_key(self, mock_requests):
745+
"""Test that the legacy API response format with 'endpoints' key still works."""
746+
mock_requests.codes.ok = 200
747+
mock_requests.get.return_value.json.return_value = {
748+
"endpoints": [
749+
{
750+
"id": "def456",
751+
"name": "Legacy Endpoint",
752+
"odbc_params": {
753+
"hostname": "xx.cloud.databricks.com",
754+
"path": "/sql/1.0/endpoints/def456",
755+
},
756+
}
757+
]
758+
}
759+
type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
760+
761+
hook = DatabricksSqlHook(sql_endpoint_name="Legacy Endpoint")
762+
endpoint = hook._get_sql_endpoint_by_name("Legacy Endpoint")
763+
assert endpoint["id"] == "def456"
764+
assert endpoint["odbc_params"]["path"] == "/sql/1.0/endpoints/def456"
765+
766+
@patch("airflow.providers.databricks.hooks.databricks_base.requests")
767+
def test_resolve_warehouse_name_not_found(self, mock_requests):
768+
"""Test that a clear error is raised when the warehouse name doesn't match any warehouse."""
769+
mock_requests.codes.ok = 200
770+
mock_requests.get.return_value.json.return_value = {
771+
"warehouses": [
772+
{
773+
"id": "abc123",
774+
"name": "Some Other Warehouse",
775+
"odbc_params": {"path": "/sql/1.0/warehouses/abc123"},
776+
}
777+
]
778+
}
779+
type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
780+
781+
hook = DatabricksSqlHook(sql_endpoint_name="Nonexistent Warehouse")
782+
with pytest.raises(ValueError, match="Can't find Databricks SQL warehouse with name"):
783+
hook._get_sql_endpoint_by_name("Nonexistent Warehouse")
784+
785+
@patch("airflow.providers.databricks.hooks.databricks_base.requests")
786+
def test_resolve_warehouse_name_empty_response(self, mock_requests):
787+
"""Test that a clear error is raised when the API returns no warehouses."""
788+
mock_requests.codes.ok = 200
789+
mock_requests.get.return_value.json.return_value = {}
790+
type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
791+
792+
hook = DatabricksSqlHook(sql_endpoint_name="Test")
793+
with pytest.raises(RuntimeError, match="Can't list Databricks SQL warehouses"):
794+
hook._get_sql_endpoint_by_name("Test")

0 commit comments

Comments
 (0)