This repository was archived by the owner on Sep 2, 2025. It is now read-only.

Commit 3839953

MichelleArk, d-cole, jtcohen6, and colin-rogers-dbt authored
Add cancel (#1251)
* Add query cancellation.
* clean up merge + linting
* Add back mp_context
* generating a fresh job_id for every _query_and_results call
* add cancellation test
* add cancellation test
* add seed cancellation
* remove type ignore
* add changie
* use defaultdict to simplify code

---------

Co-authored-by: Daniel Cole <[email protected]>
Co-authored-by: Jeremy Cohen <[email protected]>
Co-authored-by: Colin Rogers <[email protected]>
Co-authored-by: Colin <[email protected]>
1 parent 8c0a192 commit 3839953

6 files changed: +213 -29 lines changed
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+kind: Features
+body: Add support for cancelling queries on keyboard interrupt
+time: 2024-07-30T13:59:11.585452-07:00
+custom:
+  Author: d-cole MichelleArk colin-rogers-dbt
+  Issue: "917"
dbt/adapters/bigquery/connections.py

Lines changed: 61 additions & 13 deletions
@@ -1,17 +1,17 @@
+from collections import defaultdict
 from concurrent.futures import TimeoutError
 import json
 import re
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-
-from dbt_common.invocation import get_invocation_id
-
-from dbt_common.events.contextvars import get_node_info
+import uuid
 from mashumaro.helper import pass_through

 from functools import lru_cache
 from requests.exceptions import ConnectionError
-from typing import Optional, Any, Dict, Tuple, TYPE_CHECKING
+
+from multiprocessing.context import SpawnContext
+from typing import Optional, Any, Dict, Tuple, Hashable, List, TYPE_CHECKING

 import google.auth
 import google.auth.exceptions
@@ -24,19 +24,25 @@
     service_account as GoogleServiceAccountCredentials,
 )

-from dbt.adapters.bigquery import gcloud
-from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials
+from dbt_common.events.contextvars import get_node_info
+from dbt_common.events.functions import fire_event
 from dbt_common.exceptions import (
     DbtRuntimeError,
     DbtConfigError,
+    DbtDatabaseError,
+)
+from dbt_common.invocation import get_invocation_id
+from dbt.adapters.bigquery import gcloud
+from dbt.adapters.contracts.connection import (
+    ConnectionState,
+    AdapterResponse,
+    Credentials,
+    AdapterRequiredConfig,
 )
-
-from dbt_common.exceptions import DbtDatabaseError
 from dbt.adapters.exceptions.connection import FailedToConnectError
 from dbt.adapters.base import BaseConnectionManager
 from dbt.adapters.events.logging import AdapterLogger
 from dbt.adapters.events.types import SQLQuery
-from dbt_common.events.functions import fire_event
 from dbt.adapters.bigquery import __version__ as dbt_version
 from dbt.adapters.bigquery.utility import is_base64, base64_to_string

@@ -231,6 +237,10 @@ class BigQueryConnectionManager(BaseConnectionManager):
     DEFAULT_INITIAL_DELAY = 1.0  # Seconds
     DEFAULT_MAXIMUM_DELAY = 3.0  # Seconds

+    def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
+        super().__init__(profile, mp_context)
+        self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list)
+
     @classmethod
     def handle_error(cls, error, message):
         error_msg = "\n".join([item["message"] for item in error.errors])
@@ -284,11 +294,31 @@ def exception_handler(self, sql):
             exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip()
             raise DbtRuntimeError(exc_message)

-    def cancel_open(self) -> None:
-        pass
+    def cancel_open(self):
+        names = []
+        this_connection = self.get_if_exists()
+        with self.lock:
+            for thread_id, connection in self.thread_connections.items():
+                if connection is this_connection:
+                    continue
+                if connection.handle is not None and connection.state == ConnectionState.OPEN:
+                    client = connection.handle
+                    for job_id in self.jobs_by_thread.get(thread_id, []):
+
+                        def fn():
+                            return client.cancel_job(job_id)
+
+                        self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn)
+
+                    self.close(connection)
+
+                if connection.name is not None:
+                    names.append(connection.name)
+        return names

     @classmethod
     def close(cls, connection):
+        connection.handle.close()
         connection.state = ConnectionState.CLOSED

         return connection
@@ -452,6 +482,18 @@ def get_labels_from_query_comment(cls):

         return {}

+    def generate_job_id(self) -> str:
+        # Generating a fresh job_id for every _query_and_results call to avoid job_id reuse.
+        # Generating a job id instead of persisting a BigQuery-generated one after client.query is called.
+        # Using BigQuery's job_id can lead to a race condition if a job has been started and a termination
+        # is sent before the job_id was stored, leading to a failure to cancel the job.
+        # By predetermining job_ids (uuid4), we can persist the job_id before the job has been kicked off.
+        # Doing this, the race condition only leads to attempting to cancel a job that doesn't exist.
+        job_id = str(uuid.uuid4())
+        thread_id = self.get_thread_identifier()
+        self.jobs_by_thread[thread_id].append(job_id)
+        return job_id
+
     def raw_execute(
         self,
         sql,
@@ -488,10 +530,13 @@ def raw_execute(
         job_execution_timeout = self.get_job_execution_timeout_seconds(conn)

         def fn():
+            job_id = self.generate_job_id()
+
             return self._query_and_results(
                 client,
                 sql,
                 job_params,
+                job_id,
                 job_creation_timeout=job_creation_timeout,
                 job_execution_timeout=job_execution_timeout,
                 limit=limit,
@@ -731,14 +776,17 @@ def _query_and_results(
         client,
         sql,
         job_params,
+        job_id,
         job_creation_timeout=None,
         job_execution_timeout=None,
         limit: Optional[int] = None,
     ):
         """Query the client and wait for results."""
         # Cannot reuse job_config if destination is set and ddl is used
         job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
-        query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout)
+        query_job = client.query(
+            query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout
+        )
         if (
             query_job.location is not None
             and query_job.job_id is not None
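To make the race-condition reasoning in the generate_job_id comments concrete, here is a minimal standalone sketch of the pattern, assuming a configured google.cloud.bigquery.Client with default credentials; the query and variable names are illustrative, not the adapter's code:

import uuid

from google.cloud import bigquery

client = bigquery.Client()

# Pick the job_id up front so it can be recorded before the job starts;
# a canceller then always has an id to target.
job_id = str(uuid.uuid4())

# client.query accepts a caller-supplied job_id in place of a
# server-generated one.
query_job = client.query("SELECT 1", job_id=job_id)

# Cancellation targets the predetermined id. If it races job creation,
# the worst case is attempting to cancel a job that was never started.
client.cancel_job(job_id)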

dbt/adapters/bigquery/impl.py

Lines changed: 5 additions & 2 deletions
@@ -164,7 +164,7 @@ def date_function(cls) -> str:

     @classmethod
     def is_cancelable(cls) -> bool:
-        return False
+        return True

     def drop_relation(self, relation: BigQueryRelation) -> None:
         is_cached = self._schema_is_cached(relation.database, relation.schema)
@@ -693,8 +693,11 @@ def load_dataframe(
         load_config.skip_leading_rows = 1
         load_config.schema = bq_schema
         load_config.field_delimiter = field_delimiter
+        job_id = self.connections.generate_job_id()
         with open(agate_table.original_abspath, "rb") as f:  # type: ignore
-            job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config)
+            job = client.load_table_from_file(
+                f, table_ref, rewind=True, job_config=load_config, job_id=job_id
+            )

         timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300
         with self.connections.exception_handler("LOAD TABLE"):
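The seed path gets the same treatment through the load-job API, which also accepts a caller-supplied job_id. A minimal sketch, again assuming a configured client; the file name, table path, job id, and config values here are illustrative:

from google.cloud import bigquery

client = bigquery.Client()
table_ref = bigquery.TableReference.from_string("my-project.my_dataset.seed")

job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV,
    skip_leading_rows=1,
    autodetect=True,
)

# Passing job_id makes the load job addressable by client.cancel_job,
# exactly like the query jobs above.
with open("seed.csv", "rb") as f:
    job = client.load_table_from_file(
        f, table_ref, rewind=True, job_config=job_config, job_id="my-predetermined-id"
    )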

tests/functional/test_cancel.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+import time
+
+import os
+import signal
+import subprocess
+
+import pytest
+
+from dbt.tests.util import get_connection
+
+_SEED_CSV = """
+id, name, astrological_sign, moral_alignment
+1, Alice, Aries, Lawful Good
+2, Bob, Taurus, Neutral Good
+3, Thaddeus, Gemini, Chaotic Neutral
+4, Zebulon, Cancer, Lawful Evil
+5, Yorick, Leo, True Neutral
+6, Xavier, Virgo, Chaotic Evil
+7, Wanda, Libra, Lawful Neutral
+"""
+
+_LONG_RUNNING_MODEL_SQL = """
+{{ config(materialized='table') }}
+with array_1 as (
+select generated_ids from UNNEST(GENERATE_ARRAY(1, 200000)) AS generated_ids
+),
+array_2 as (
+select generated_ids from UNNEST(GENERATE_ARRAY(2, 200000)) AS generated_ids
+)
+
+SELECT array_1.generated_ids
+FROM array_1
+LEFT JOIN array_1 as jnd on 1=1
+LEFT JOIN array_2 as jnd2 on 1=1
+LEFT JOIN array_1 as jnd3 on jnd3.generated_ids >= jnd2.generated_ids
+"""
+
+
+def _get_info_schema_jobs_query(project_id, dataset_id, table_id):
+    """
+    Running this query requires roles/bigquery.resourceViewer on the project,
+    see: https://cloud.google.com/bigquery/docs/information-schema-jobs#required_role
+    :param project_id:
+    :param dataset_id:
+    :param table_id:
+    :return: a single job id that matches the model we tried to create and was cancelled
+    """
+    return f"""
+    SELECT job_id
+    FROM `region-us`.`INFORMATION_SCHEMA.JOBS_BY_PROJECT`
+    WHERE creation_time > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 5 HOUR)
+    AND statement_type = 'CREATE_TABLE_AS_SELECT'
+    AND state = 'DONE'
+    AND job_id IS NOT NULL
+    AND project_id = '{project_id}'
+    AND error_result.reason = 'stopped'
+    AND error_result.message = 'Job execution was cancelled: User requested cancellation'
+    AND destination_table.table_id = '{table_id}'
+    AND destination_table.dataset_id = '{dataset_id}'
+    """
+
+
+def _run_dbt_in_subprocess(project, dbt_command):
+    os.chdir(project.project_root)
+    run_dbt_process = subprocess.Popen(
+        [
+            "dbt",
+            dbt_command,
+            "--profiles-dir",
+            project.profiles_dir,
+            "--project-dir",
+            project.project_root,
+        ],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        shell=False,
+    )
+    std_out_log = ""
+    while True:
+        std_out_line = run_dbt_process.stdout.readline().decode("utf-8")
+        std_out_log += std_out_line
+        if std_out_line != "":
+            print(std_out_line)
+            if "1 of 1 START" in std_out_line:
+                time.sleep(1)
+                run_dbt_process.send_signal(signal.SIGINT)
+
+        if run_dbt_process.poll():
+            break
+
+    return std_out_log
+
+
+def _get_job_id(project, table_name):
+    # Because we run this in a subprocess we have to actually call Bigquery and look up the job id
+    with get_connection(project.adapter):
+        job_id = project.run_sql(
+            _get_info_schema_jobs_query(project.database, project.test_schema, table_name)
+        )
+
+    return job_id
+
+
+class TestBigqueryCancelsQueriesOnKeyboardInterrupt:
+    @pytest.fixture(scope="class", autouse=True)
+    def models(self):
+        return {
+            "model.sql": _LONG_RUNNING_MODEL_SQL,
+        }
+
+    @pytest.fixture(scope="class", autouse=True)
+    def seeds(self):
+        return {
+            "seed.csv": _SEED_CSV,
+        }
+
+    def test_bigquery_cancels_queries_for_model_on_keyboard_interrupt(self, project):
+        std_out_log = _run_dbt_in_subprocess(project, "run")
+
+        assert "CANCEL query model.test.model" in std_out_log
+        assert len(_get_job_id(project, "model")) == 1
+
+    def test_bigquery_cancels_queries_for_seed_on_keyboard_interrupt(self, project):
+        std_out_log = _run_dbt_in_subprocess(project, "seed")
+
+        assert "CANCEL query seed.test.seed" in std_out_log
+        # we can't assert the job id since we can't kill the seed process fast enough to cancel it

tests/unit/test_bigquery_adapter.py

Lines changed: 11 additions & 11 deletions
@@ -32,6 +32,7 @@
     inject_adapter,
     TestAdapterConversions,
     load_internal_manifest_macros,
+    mock_connection,
 )

@@ -368,23 +369,22 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):

     def test_cancel_open_connections_empty(self):
         adapter = self.get_adapter("oauth")
-        self.assertEqual(adapter.cancel_open_connections(), None)
+        self.assertEqual(len(list(adapter.cancel_open_connections())), 0)

     def test_cancel_open_connections_master(self):
         adapter = self.get_adapter("oauth")
-        adapter.connections.thread_connections[0] = object()
-        self.assertEqual(adapter.cancel_open_connections(), None)
+        key = adapter.connections.get_thread_identifier()
+        adapter.connections.thread_connections[key] = mock_connection("master")
+        self.assertEqual(len(list(adapter.cancel_open_connections())), 0)

     def test_cancel_open_connections_single(self):
         adapter = self.get_adapter("oauth")
-        adapter.connections.thread_connections.update(
-            {
-                0: object(),
-                1: object(),
-            }
-        )
-        # actually does nothing
-        self.assertEqual(adapter.cancel_open_connections(), None)
+        master = mock_connection("master")
+        model = mock_connection("model")
+        key = adapter.connections.get_thread_identifier()
+
+        adapter.connections.thread_connections.update({key: master, 1: model})
+        self.assertEqual(len(list(adapter.cancel_open_connections())), 1)

     @patch("dbt.adapters.bigquery.impl.google.auth.default")
     @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
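These unit-test changes pin down the new contract: cancel_open now returns the names of the connections it cancelled rather than None. A hedged caller-side sketch of how those names could surface as the "CANCEL query" lines the functional test asserts on (this assumes a constructed adapter and is not dbt-core's actual logging code):

# On KeyboardInterrupt, dbt asks the adapter to cancel open connections
# and reports each returned connection name.
for name in adapter.cancel_open_connections():
    print(f"CANCEL query {name}")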

tests/unit/test_bigquery_connection_manager.py

Lines changed: 3 additions & 3 deletions
@@ -104,18 +104,18 @@ def test_drop_dataset(self):

     @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
     def test_query_and_results(self, mock_bq):
-        self.mock_client.query = Mock(return_value=Mock(state="DONE"))
         self.connections._query_and_results(
             self.mock_client,
             "sql",
             {"job_param_1": "blah"},
+            job_id=1,
             job_creation_timeout=15,
-            job_execution_timeout=3,
+            job_execution_timeout=100,
         )

         mock_bq.QueryJobConfig.assert_called_once()
         self.mock_client.query.assert_called_once_with(
-            query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
+            query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
         )

     def test_copy_bq_table_appends(self):