Skip to content

Commit 69aa772

Browse files
authored
Extract poll refresh pipeline from cursor (#849)
1 parent 5066e76 commit 69aa772

File tree

6 files changed

+146
-167
lines changed

6 files changed

+146
-167
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
- Fix behavior flag use in init of DatabricksAdapter (thanks @VersusFacit!) ([836](https://github.com/databricks/dbt-databricks/pull/836))
3636
- Restrict pydantic to V1 per dbt Labs' request ([843](https://github.com/databricks/dbt-databricks/pull/843))
3737
- Switching to Ruff for formatting and linting ([847](https://github.com/databricks/dbt-databricks/pull/847))
38+
- Refactoring location of DLT polling code ([849](https://github.com/databricks/dbt-databricks/pull/849))
3839
- Switching to Hatch and pyproject.toml for project config ([853](https://github.com/databricks/dbt-databricks/pull/853))
3940

4041
## dbt-databricks 1.8.7 (October 10, 2024)

dbt/adapters/databricks/api_client.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,60 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
460460
return response_json["run_id"]
461461

462462

463+
class DltPipelineApi(PollableApi):
    """Client for the Delta Live Tables pipelines API (/api/2.0/pipelines).

    Polls a pipeline refresh until it reaches a terminal state and surfaces a
    useful error message when the refresh fails.
    """

    def __init__(self, session: Session, host: str, polling_interval: int):
        # 60 * 60: allow a pipeline refresh up to one hour before timing out.
        super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60)

    def poll_for_completion(self, pipeline_id: str) -> None:
        """Block until the pipeline leaves its transient states.

        Raises DbtRuntimeError (via _get_exception) when the pipeline ends in
        a terminal state other than IDLE.
        """
        self._poll_api(
            url=f"/{pipeline_id}",
            params={},
            get_state_func=lambda response: response.json()["state"],
            terminal_states={"IDLE", "FAILED", "DELETED"},
            expected_end_state="IDLE",
            unexpected_end_state_func=self._get_exception,
        )

    def _get_exception(self, response: Response) -> None:
        """Raise a DbtRuntimeError describing why the pipeline did not end IDLE."""
        response_json = response.json()
        pipeline_id = response_json.get("pipeline_id")
        cause = response_json.get("cause")
        if cause:
            raise DbtRuntimeError(f"Pipeline {pipeline_id} failed: {cause}")

        # No top-level cause: pull the failure message from the events of the
        # most recent update (latest_updates is ordered newest-first —
        # NOTE(review): per pipelines API docs; confirm).
        latest_updates = response_json.get("latest_updates") or []
        if not latest_updates:
            # Fix: the original indexed [0] on a possibly-missing list, raising
            # TypeError/IndexError instead of a meaningful error.
            raise DbtRuntimeError(f"Pipeline {pipeline_id} failed")

        # Fix: the original passed the whole update dict to get_update_error,
        # which expects an update_id string to compare against
        # origin.update_id, so the event filter never matched and the reported
        # message was always empty.
        update_id = latest_updates[0].get("update_id", "")
        last_error = self.get_update_error(pipeline_id, update_id)
        raise DbtRuntimeError(f"Pipeline {pipeline_id} failed: {last_error}")

    def get_update_error(self, pipeline_id: str, update_id: str) -> str:
        """Return the failure message for the given update, or "" if none found.

        Scans the pipeline's event log for 'update_progress' events belonging
        to update_id whose reported state is FAILED.

        Raises DbtRuntimeError if the events endpoint returns a non-200 status.
        """
        response = self.session.get(f"/{pipeline_id}/events")
        if response.status_code != 200:
            raise DbtRuntimeError(
                f"Error getting pipeline event info for {pipeline_id}: {response.text}"
            )

        events = response.json().get("events", [])
        # Progress events that belong to the update we are interested in.
        update_events = [
            e
            for e in events
            if e.get("event_type", "") == "update_progress"
            and e.get("origin", {}).get("update_id") == update_id
        ]

        # Of those, only the ones reporting a FAILED state carry the message.
        error_events = [
            e
            for e in update_events
            if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
        ]

        msg = ""
        if error_events:
            msg = error_events[0].get("message", "")

        return msg
515+
516+
463517
class DatabricksApiClient:
464518
def __init__(
465519
self,
@@ -481,6 +535,7 @@ def __init__(
481535
self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
482536
self.workflows = WorkflowJobApi(session, host)
483537
self.workflow_permissions = JobPermissionsApi(session, host)
538+
self.dlt_pipelines = DltPipelineApi(session, host, polling_interval)
484539

485540
@staticmethod
486541
def create(

dbt/adapters/databricks/connections.py

Lines changed: 6 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from dbt_common.events.functions import fire_event
1818
from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError
1919
from dbt_common.utils import cast_to_str
20-
from requests import Session
2120

2221
import databricks.sql as dbsql
2322
from databricks.sql.client import Connection as DatabricksSQLConnection
@@ -35,7 +34,6 @@
3534
)
3635
from dbt.adapters.databricks.__version__ import version as __version__
3736
from dbt.adapters.databricks.api_client import DatabricksApiClient
38-
from dbt.adapters.databricks.auth import BearerAuth
3937
from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
4038
from dbt.adapters.databricks.events.connection_events import (
4139
ConnectionAcquire,
@@ -61,7 +59,6 @@
6159
CursorCreate,
6260
)
6361
from dbt.adapters.databricks.events.other_events import QueryError
64-
from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError
6562
from dbt.adapters.databricks.logging import logger
6663
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
6764
from dbt.adapters.databricks.utils import redact_credentials
@@ -227,97 +224,6 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
227224
bindings = [self._fix_binding(binding) for binding in bindings]
228225
self._cursor.execute(sql, bindings)
229226

230-
def poll_refresh_pipeline(self, pipeline_id: str) -> None:
231-
# interval in seconds
232-
polling_interval = 10
233-
234-
# timeout in seconds
235-
timeout = 60 * 60
236-
237-
stopped_states = ("COMPLETED", "FAILED", "CANCELED")
238-
host: str = self._creds.host or ""
239-
headers = (
240-
self._cursor.connection.thrift_backend._auth_provider._header_factory # type: ignore
241-
)
242-
session = Session()
243-
session.auth = BearerAuth(headers)
244-
session.headers = {"User-Agent": self._user_agent}
245-
pipeline = _get_pipeline_state(session, host, pipeline_id)
246-
# get the most recently created update for the pipeline
247-
latest_update = _find_update(pipeline)
248-
if not latest_update:
249-
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")
250-
251-
state = latest_update.get("state")
252-
# we use update_id to retrieve the update in the polling loop
253-
update_id = latest_update.get("update_id", "")
254-
prev_state = state
255-
256-
logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
257-
258-
start = time.time()
259-
exceeded_timeout = False
260-
while state not in stopped_states:
261-
if time.time() - start > timeout:
262-
exceeded_timeout = True
263-
break
264-
265-
# should we do exponential backoff?
266-
time.sleep(polling_interval)
267-
268-
pipeline = _get_pipeline_state(session, host, pipeline_id)
269-
# get the update we are currently polling
270-
update = _find_update(pipeline, update_id)
271-
if not update:
272-
raise DbtRuntimeError(
273-
f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
274-
)
275-
276-
state = update.get("state")
277-
if state != prev_state:
278-
logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
279-
prev_state = state
280-
281-
if state == "FAILED":
282-
logger.error(
283-
PipelineRefreshError(
284-
pipeline_id,
285-
update_id,
286-
_get_update_error_msg(session, host, pipeline_id, update_id),
287-
)
288-
)
289-
290-
# another update may have been created due to retry_on_fail settings
291-
# get the latest update and see if it is a new one
292-
latest_update = _find_update(pipeline)
293-
if not latest_update:
294-
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")
295-
296-
latest_update_id = latest_update.get("update_id", "")
297-
if latest_update_id != update_id:
298-
update_id = latest_update_id
299-
state = None
300-
301-
if exceeded_timeout:
302-
raise DbtRuntimeError("timed out waiting for materialized view refresh")
303-
304-
if state == "FAILED":
305-
msg = _get_update_error_msg(session, host, pipeline_id, update_id)
306-
raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")
307-
308-
if state == "CANCELED":
309-
raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")
310-
311-
return
312-
313-
@classmethod
314-
def findUpdate(cls, updates: list, id: str) -> Optional[dict]:
315-
matches = [x for x in updates if x.get("update_id") == id]
316-
if matches:
317-
return matches[0]
318-
319-
return None
320-
321227
@property
322228
def hex_query_id(self) -> str:
323229
"""Return the hex GUID for this query
@@ -475,12 +381,15 @@ class DatabricksConnectionManager(SparkConnectionManager):
475381
credentials_provider: Optional[TCredentialProvider] = None
476382
_user_agent = f"dbt-databricks/{__version__}"
477383

384+
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
    """Initialize the connection manager and its shared REST API client."""
    super().__init__(profile, mp_context)
    creds = cast(DatabricksCredentials, self.profile.credentials)
    # Build the client once per manager so REST operations (e.g. cancelling
    # python job runs) reuse it instead of re-creating it per call.
    # 15 * 60: timeout passed to DatabricksApiClient.create — presumably
    # seconds (15 minutes); NOTE(review): confirm against create()'s signature.
    self.api_client = DatabricksApiClient.create(creds, 15 * 60)
388+
478389
def cancel_open(self) -> list[str]:
479390
cancelled = super().cancel_open()
480-
creds = cast(DatabricksCredentials, self.profile.credentials)
481-
api_client = DatabricksApiClient.create(creds, 15 * 60)
482391
logger.info("Cancelling open python jobs")
483-
PythonRunTracker.cancel_runs(api_client)
392+
PythonRunTracker.cancel_runs(self.api_client)
484393
return cancelled
485394

486395
def compare_dbr_version(self, major: int, minor: int) -> int:
@@ -1079,60 +988,6 @@ def exponential_backoff(attempt: int) -> int:
1079988
)
1080989

1081990

1082-
def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
1083-
pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"
1084-
1085-
response = session.get(pipeline_url)
1086-
if response.status_code != 200:
1087-
raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")
1088-
1089-
return response.json()
1090-
1091-
1092-
def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
1093-
updates = pipeline.get("latest_updates", [])
1094-
if not updates:
1095-
raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")
1096-
1097-
if not id:
1098-
return updates[0]
1099-
1100-
matches = [x for x in updates if x.get("update_id") == id]
1101-
if matches:
1102-
return matches[0]
1103-
1104-
return None
1105-
1106-
1107-
def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str:
1108-
events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
1109-
response = session.get(events_url)
1110-
if response.status_code != 200:
1111-
raise DbtRuntimeError(
1112-
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
1113-
)
1114-
1115-
events = response.json().get("events", [])
1116-
update_events = [
1117-
e
1118-
for e in events
1119-
if e.get("event_type", "") == "update_progress"
1120-
and e.get("origin", {}).get("update_id") == update_id
1121-
]
1122-
1123-
error_events = [
1124-
e
1125-
for e in update_events
1126-
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
1127-
]
1128-
1129-
msg = ""
1130-
if error_events:
1131-
msg = error_events[0].get("message", "")
1132-
1133-
return msg
1134-
1135-
1136991
def _get_compute_name(query_header_context: Any) -> Optional[str]:
1137992
# Get the name of the specified compute resource from the node's
1138993
# config.

dbt/adapters/databricks/impl.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
from dbt.adapters.databricks.connections import (
3535
USE_LONG_SESSIONS,
3636
DatabricksConnectionManager,
37-
DatabricksDBTConnection,
38-
DatabricksSQLConnectionWrapper,
3937
ExtendedSessionConnectionManager,
4038
)
4139
from dbt.adapters.databricks.python_models.python_submissions import (
@@ -807,19 +805,13 @@ def get_from_relation(
807805
"""Get the relation config from the relation."""
808806

809807
relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
810-
connection = cast(DatabricksDBTConnection, adapter.connections.get_thread_connection())
811-
wrapper: DatabricksSQLConnectionWrapper = connection.handle
812808

813809
# Ensure any current refreshes are completed before returning the relation config
814810
tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
815811
if tblproperties.pipeline_id:
816-
# TODO fix this path so that it doesn't need a cursor
817-
# It just calls APIs to poll the pipeline status
818-
cursor = wrapper.cursor()
819-
try:
820-
cursor.poll_refresh_pipeline(tblproperties.pipeline_id)
821-
finally:
822-
cursor.close()
812+
adapter.connections.api_client.dlt_pipelines.poll_for_completion(
813+
tblproperties.pipeline_id
814+
)
823815
return relation_config
824816

825817

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
from dbt_common.exceptions import DbtRuntimeError
3+
4+
from dbt.adapters.databricks.api_client import DltPipelineApi
5+
from tests.unit.api_client.api_test_base import ApiTestBase
6+
7+
8+
class TestDltPipelineApi(ApiTestBase):
    """Unit tests for DltPipelineApi: update-error lookup and completion polling."""

    @pytest.fixture
    def api(self, session, host):
        # Polling interval of 1 second keeps the poll loop fast in tests.
        return DltPipelineApi(session, host, 1)

    @pytest.fixture
    def pipeline_id(self):
        return "pipeline_id"

    @pytest.fixture
    def update_id(self):
        return "update_id"

    def test_get_update_error__non_200(self, api, session, pipeline_id, update_id):
        """A non-200 events response is surfaced as a DbtRuntimeError."""
        response = session.get.return_value
        response.status_code = 500
        with pytest.raises(DbtRuntimeError):
            api.get_update_error(pipeline_id, update_id)

    def test_get_update_error__200_no_events(self, api, session, pipeline_id, update_id):
        """An empty event log yields an empty error message."""
        response = session.get.return_value
        response.status_code = 200
        response.json.return_value = {"events": []}
        assert api.get_update_error(pipeline_id, update_id) == ""

    def test_get_update_error__200_no_error_events(self, api, session, pipeline_id, update_id):
        """Progress events without a FAILED state yield an empty error message."""
        progress_event = {"event_type": "update_progress", "origin": {"update_id": update_id}}
        response = session.get.return_value
        response.status_code = 200
        response.json.return_value = {"events": [progress_event]}
        assert api.get_update_error(pipeline_id, update_id) == ""

    def test_get_update_error__200_error_events(self, api, session, pipeline_id, update_id):
        """A FAILED progress event for the update supplies the error message."""
        failed_event = {
            "message": "I failed",
            "details": {"update_progress": {"state": "FAILED"}},
            "event_type": "update_progress",
            "origin": {"update_id": update_id},
        }
        response = session.get.return_value
        response.status_code = 200
        response.json.return_value = {"events": [failed_event]}
        assert api.get_update_error(pipeline_id, update_id) == "I failed"

    def test_poll_for_completion__non_200(self, api, session, pipeline_id):
        """A non-200 pipeline status response raises via the shared helper."""
        self.assert_non_200_raises_error(lambda: api.poll_for_completion(pipeline_id), session)

    def test_poll_for_completion__200(self, api, session, host, pipeline_id):
        """An IDLE pipeline completes after a single status request."""
        response = session.get.return_value
        response.status_code = 200
        response.json.return_value = {"state": "IDLE"}
        api.poll_for_completion(pipeline_id)
        session.get.assert_called_once_with(
            f"https://{host}/api/2.0/pipelines/{pipeline_id}", json=None, params={}
        )

    def test_poll_for_completion__failed_with_cause(self, api, session, pipeline_id):
        """A FAILED pipeline with a top-level cause raises with that cause."""
        response = session.get.return_value
        response.status_code = 200
        response.json.return_value = {
            "state": "FAILED",
            "pipeline_id": pipeline_id,
            "cause": "I failed",
        }
        with pytest.raises(DbtRuntimeError, match=f"Pipeline {pipeline_id} failed: I failed"):
            api.poll_for_completion(pipeline_id)

0 commit comments

Comments
 (0)