|
17 | 17 | from dbt_common.events.functions import fire_event |
18 | 18 | from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError |
19 | 19 | from dbt_common.utils import cast_to_str |
20 | | -from requests import Session |
21 | 20 |
|
22 | 21 | import databricks.sql as dbsql |
23 | 22 | from databricks.sql.client import Connection as DatabricksSQLConnection |
|
35 | 34 | ) |
36 | 35 | from dbt.adapters.databricks.__version__ import version as __version__ |
37 | 36 | from dbt.adapters.databricks.api_client import DatabricksApiClient |
38 | | -from dbt.adapters.databricks.auth import BearerAuth |
39 | 37 | from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider |
40 | 38 | from dbt.adapters.databricks.events.connection_events import ( |
41 | 39 | ConnectionAcquire, |
|
61 | 59 | CursorCreate, |
62 | 60 | ) |
63 | 61 | from dbt.adapters.databricks.events.other_events import QueryError |
64 | | -from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError |
65 | 62 | from dbt.adapters.databricks.logging import logger |
66 | 63 | from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker |
67 | 64 | from dbt.adapters.databricks.utils import redact_credentials |
@@ -227,97 +224,6 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None: |
227 | 224 | bindings = [self._fix_binding(binding) for binding in bindings] |
228 | 225 | self._cursor.execute(sql, bindings) |
229 | 226 |
|
230 | | - def poll_refresh_pipeline(self, pipeline_id: str) -> None: |
231 | | - # interval in seconds |
232 | | - polling_interval = 10 |
233 | | - |
234 | | - # timeout in seconds |
235 | | - timeout = 60 * 60 |
236 | | - |
237 | | - stopped_states = ("COMPLETED", "FAILED", "CANCELED") |
238 | | - host: str = self._creds.host or "" |
239 | | - headers = ( |
240 | | - self._cursor.connection.thrift_backend._auth_provider._header_factory # type: ignore |
241 | | - ) |
242 | | - session = Session() |
243 | | - session.auth = BearerAuth(headers) |
244 | | - session.headers = {"User-Agent": self._user_agent} |
245 | | - pipeline = _get_pipeline_state(session, host, pipeline_id) |
246 | | - # get the most recently created update for the pipeline |
247 | | - latest_update = _find_update(pipeline) |
248 | | - if not latest_update: |
249 | | - raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}") |
250 | | - |
251 | | - state = latest_update.get("state") |
252 | | - # we use update_id to retrieve the update in the polling loop |
253 | | - update_id = latest_update.get("update_id", "") |
254 | | - prev_state = state |
255 | | - |
256 | | - logger.info(PipelineRefresh(pipeline_id, update_id, str(state))) |
257 | | - |
258 | | - start = time.time() |
259 | | - exceeded_timeout = False |
260 | | - while state not in stopped_states: |
261 | | - if time.time() - start > timeout: |
262 | | - exceeded_timeout = True |
263 | | - break |
264 | | - |
265 | | - # should we do exponential backoff? |
266 | | - time.sleep(polling_interval) |
267 | | - |
268 | | - pipeline = _get_pipeline_state(session, host, pipeline_id) |
269 | | - # get the update we are currently polling |
270 | | - update = _find_update(pipeline, update_id) |
271 | | - if not update: |
272 | | - raise DbtRuntimeError( |
273 | | - f"Error getting pipeline update info: {pipeline_id}, update: {update_id}" |
274 | | - ) |
275 | | - |
276 | | - state = update.get("state") |
277 | | - if state != prev_state: |
278 | | - logger.info(PipelineRefresh(pipeline_id, update_id, str(state))) |
279 | | - prev_state = state |
280 | | - |
281 | | - if state == "FAILED": |
282 | | - logger.error( |
283 | | - PipelineRefreshError( |
284 | | - pipeline_id, |
285 | | - update_id, |
286 | | - _get_update_error_msg(session, host, pipeline_id, update_id), |
287 | | - ) |
288 | | - ) |
289 | | - |
290 | | - # another update may have been created due to retry_on_fail settings |
291 | | - # get the latest update and see if it is a new one |
292 | | - latest_update = _find_update(pipeline) |
293 | | - if not latest_update: |
294 | | - raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}") |
295 | | - |
296 | | - latest_update_id = latest_update.get("update_id", "") |
297 | | - if latest_update_id != update_id: |
298 | | - update_id = latest_update_id |
299 | | - state = None |
300 | | - |
301 | | - if exceeded_timeout: |
302 | | - raise DbtRuntimeError("timed out waiting for materialized view refresh") |
303 | | - |
304 | | - if state == "FAILED": |
305 | | - msg = _get_update_error_msg(session, host, pipeline_id, update_id) |
306 | | - raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}") |
307 | | - |
308 | | - if state == "CANCELED": |
309 | | - raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled") |
310 | | - |
311 | | - return |
312 | | - |
313 | | - @classmethod |
314 | | - def findUpdate(cls, updates: list, id: str) -> Optional[dict]: |
315 | | - matches = [x for x in updates if x.get("update_id") == id] |
316 | | - if matches: |
317 | | - return matches[0] |
318 | | - |
319 | | - return None |
320 | | - |
321 | 227 | @property |
322 | 228 | def hex_query_id(self) -> str: |
323 | 229 | """Return the hex GUID for this query |
@@ -475,12 +381,15 @@ class DatabricksConnectionManager(SparkConnectionManager): |
475 | 381 | credentials_provider: Optional[TCredentialProvider] = None |
476 | 382 | _user_agent = f"dbt-databricks/{__version__}" |
477 | 383 |
|
| 384 | + def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): |
| 385 | + super().__init__(profile, mp_context) |
| 386 | + creds = cast(DatabricksCredentials, self.profile.credentials) |
| 387 | + self.api_client = DatabricksApiClient.create(creds, 15 * 60) |
| 388 | + |
478 | 389 | def cancel_open(self) -> list[str]: |
479 | 390 | cancelled = super().cancel_open() |
480 | | - creds = cast(DatabricksCredentials, self.profile.credentials) |
481 | | - api_client = DatabricksApiClient.create(creds, 15 * 60) |
482 | 391 | logger.info("Cancelling open python jobs") |
483 | | - PythonRunTracker.cancel_runs(api_client) |
| 392 | + PythonRunTracker.cancel_runs(self.api_client) |
484 | 393 | return cancelled |
485 | 394 |
|
486 | 395 | def compare_dbr_version(self, major: int, minor: int) -> int: |
@@ -1079,60 +988,6 @@ def exponential_backoff(attempt: int) -> int: |
1079 | 988 | ) |
1080 | 989 |
|
1081 | 990 |
|
1082 | | -def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict: |
1083 | | - pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}" |
1084 | | - |
1085 | | - response = session.get(pipeline_url) |
1086 | | - if response.status_code != 200: |
1087 | | - raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}") |
1088 | | - |
1089 | | - return response.json() |
1090 | | - |
1091 | | - |
1092 | | -def _find_update(pipeline: dict, id: str = "") -> Optional[dict]: |
1093 | | - updates = pipeline.get("latest_updates", []) |
1094 | | - if not updates: |
1095 | | - raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}") |
1096 | | - |
1097 | | - if not id: |
1098 | | - return updates[0] |
1099 | | - |
1100 | | - matches = [x for x in updates if x.get("update_id") == id] |
1101 | | - if matches: |
1102 | | - return matches[0] |
1103 | | - |
1104 | | - return None |
1105 | | - |
1106 | | - |
1107 | | -def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str: |
1108 | | - events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events" |
1109 | | - response = session.get(events_url) |
1110 | | - if response.status_code != 200: |
1111 | | - raise DbtRuntimeError( |
1112 | | - f"Error getting pipeline event info for {pipeline_id}: {response.text}" |
1113 | | - ) |
1114 | | - |
1115 | | - events = response.json().get("events", []) |
1116 | | - update_events = [ |
1117 | | - e |
1118 | | - for e in events |
1119 | | - if e.get("event_type", "") == "update_progress" |
1120 | | - and e.get("origin", {}).get("update_id") == update_id |
1121 | | - ] |
1122 | | - |
1123 | | - error_events = [ |
1124 | | - e |
1125 | | - for e in update_events |
1126 | | - if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED" |
1127 | | - ] |
1128 | | - |
1129 | | - msg = "" |
1130 | | - if error_events: |
1131 | | - msg = error_events[0].get("message", "") |
1132 | | - |
1133 | | - return msg |
1134 | | - |
1135 | | - |
1136 | 991 | def _get_compute_name(query_header_context: Any) -> Optional[str]: |
1137 | 992 | # Get the name of the specified compute resource from the node's |
1138 | 993 | # config. |
|
0 commit comments