|
| 1 | +from collections import defaultdict |
1 | 2 | from concurrent.futures import TimeoutError |
2 | 3 | import json |
3 | 4 | import re |
4 | 5 | from contextlib import contextmanager |
5 | 6 | from dataclasses import dataclass, field |
6 | | - |
7 | | -from dbt_common.invocation import get_invocation_id |
8 | | - |
9 | | -from dbt_common.events.contextvars import get_node_info |
| 7 | +import uuid |
10 | 8 | from mashumaro.helper import pass_through |
11 | 9 |
|
12 | 10 | from functools import lru_cache |
13 | 11 | from requests.exceptions import ConnectionError |
14 | | -from typing import Optional, Any, Dict, Tuple, TYPE_CHECKING |
| 12 | + |
| 13 | +from multiprocessing.context import SpawnContext |
| 14 | +from typing import Optional, Any, Dict, Tuple, Hashable, List, TYPE_CHECKING |
15 | 15 |
|
16 | 16 | import google.auth |
17 | 17 | import google.auth.exceptions |
|
24 | 24 | service_account as GoogleServiceAccountCredentials, |
25 | 25 | ) |
26 | 26 |
|
27 | | -from dbt.adapters.bigquery import gcloud |
28 | | -from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials |
| 27 | +from dbt_common.events.contextvars import get_node_info |
| 28 | +from dbt_common.events.functions import fire_event |
29 | 29 | from dbt_common.exceptions import ( |
30 | 30 | DbtRuntimeError, |
31 | 31 | DbtConfigError, |
| 32 | + DbtDatabaseError, |
| 33 | +) |
| 34 | +from dbt_common.invocation import get_invocation_id |
| 35 | +from dbt.adapters.bigquery import gcloud |
| 36 | +from dbt.adapters.contracts.connection import ( |
| 37 | + ConnectionState, |
| 38 | + AdapterResponse, |
| 39 | + Credentials, |
| 40 | + AdapterRequiredConfig, |
32 | 41 | ) |
33 | | - |
34 | | -from dbt_common.exceptions import DbtDatabaseError |
35 | 42 | from dbt.adapters.exceptions.connection import FailedToConnectError |
36 | 43 | from dbt.adapters.base import BaseConnectionManager |
37 | 44 | from dbt.adapters.events.logging import AdapterLogger |
38 | 45 | from dbt.adapters.events.types import SQLQuery |
39 | | -from dbt_common.events.functions import fire_event |
40 | 46 | from dbt.adapters.bigquery import __version__ as dbt_version |
41 | 47 | from dbt.adapters.bigquery.utility import is_base64, base64_to_string |
42 | 48 |
|
@@ -231,6 +237,10 @@ class BigQueryConnectionManager(BaseConnectionManager): |
231 | 237 | DEFAULT_INITIAL_DELAY = 1.0 # Seconds |
232 | 238 | DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds |
233 | 239 |
|
| 240 | + def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): |
| 241 | + super().__init__(profile, mp_context) |
| 242 | + self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list) |
| 243 | + |
234 | 244 | @classmethod |
235 | 245 | def handle_error(cls, error, message): |
236 | 246 | error_msg = "\n".join([item["message"] for item in error.errors]) |
@@ -284,11 +294,31 @@ def exception_handler(self, sql): |
284 | 294 | exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip() |
285 | 295 | raise DbtRuntimeError(exc_message) |
286 | 296 |
|
287 | | - def cancel_open(self) -> None: |
288 | | - pass |
| 297 | + def cancel_open(self): |
| 298 | + names = [] |
| 299 | + this_connection = self.get_if_exists() |
| 300 | + with self.lock: |
| 301 | + for thread_id, connection in self.thread_connections.items(): |
| 302 | + if connection is this_connection: |
| 303 | + continue |
| 304 | + if connection.handle is not None and connection.state == ConnectionState.OPEN: |
| 305 | + client = connection.handle |
| 306 | + for job_id in self.jobs_by_thread.get(thread_id, []): |
| 307 | + |
| 308 | + def fn(): |
| 309 | + return client.cancel_job(job_id) |
| 310 | + |
| 311 | + self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn) |
| 312 | + |
| 313 | + self.close(connection) |
| 314 | + |
| 315 | + if connection.name is not None: |
| 316 | + names.append(connection.name) |
| 317 | + return names |
289 | 318 |
|
290 | 319 | @classmethod |
291 | 320 | def close(cls, connection): |
| 321 | + connection.handle.close() |
292 | 322 | connection.state = ConnectionState.CLOSED |
293 | 323 |
|
294 | 324 | return connection |
@@ -452,6 +482,18 @@ def get_labels_from_query_comment(cls): |
452 | 482 |
|
453 | 483 | return {} |
454 | 484 |
|
| 485 | + def generate_job_id(self) -> str: |
| 486 | + # Generating a fresh job_id for every _query_and_results call to avoid job_id reuse. |
| 487 | + # Generating a job id instead of persisting a BigQuery-generated one after client.query is called. |
| 488 | + # Using BigQuery's job_id can lead to a race condition if a job has been started and a termination |
| 489 | + # is sent before the job_id was stored, leading to a failure to cancel the job. |
| 490 | + # By predetermining job_ids (uuid4), we can persist the job_id before the job has been kicked off. |
| 491 | + # Doing this, the race condition only leads to attempting to cancel a job that doesn't exist. |
| 492 | + job_id = str(uuid.uuid4()) |
| 493 | + thread_id = self.get_thread_identifier() |
| 494 | + self.jobs_by_thread[thread_id].append(job_id) |
| 495 | + return job_id |
| 496 | + |
455 | 497 | def raw_execute( |
456 | 498 | self, |
457 | 499 | sql, |
@@ -488,10 +530,13 @@ def raw_execute( |
488 | 530 | job_execution_timeout = self.get_job_execution_timeout_seconds(conn) |
489 | 531 |
|
490 | 532 | def fn(): |
| 533 | + job_id = self.generate_job_id() |
| 534 | + |
491 | 535 | return self._query_and_results( |
492 | 536 | client, |
493 | 537 | sql, |
494 | 538 | job_params, |
| 539 | + job_id, |
495 | 540 | job_creation_timeout=job_creation_timeout, |
496 | 541 | job_execution_timeout=job_execution_timeout, |
497 | 542 | limit=limit, |
@@ -731,14 +776,17 @@ def _query_and_results( |
731 | 776 | client, |
732 | 777 | sql, |
733 | 778 | job_params, |
| 779 | + job_id, |
734 | 780 | job_creation_timeout=None, |
735 | 781 | job_execution_timeout=None, |
736 | 782 | limit: Optional[int] = None, |
737 | 783 | ): |
738 | 784 | """Query the client and wait for results.""" |
739 | 785 | # Cannot reuse job_config if destination is set and ddl is used |
740 | 786 | job_config = google.cloud.bigquery.QueryJobConfig(**job_params) |
741 | | - query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout) |
| 787 | + query_job = client.query( |
| 788 | + query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout |
| 789 | + ) |
742 | 790 | if ( |
743 | 791 | query_job.location is not None |
744 | 792 | and query_job.job_id is not None |
|
0 commit comments