Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepnote_toolkit/runtime_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import psycopg2.extensions
import psycopg2.extras

from deepnote_toolkit.runtime_patches import apply_runtime_patches

from .dataframe_utils import add_formatters
from .execute_post_start_hooks import execute_post_start_hooks
from .logging import LoggerManager
Expand All @@ -24,6 +26,11 @@ def init_deepnote_runtime():

logger.debug("Initializing Deepnote runtime environment started.")

try:
apply_runtime_patches()
except Exception as e:
logger.error("Failed to apply runtime patches with a error: %s", e)

# Register sparksql magic
try:
IPython.get_ipython().register_magics(SparkSql)
Expand Down
53 changes: 53 additions & 0 deletions deepnote_toolkit/runtime_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Optional, Union

from deepnote_toolkit.logging import LoggerManager

logger = LoggerManager().get_logger()


# TODO(BLU-5171): Temporary hack to allow cancelling BigQuery jobs on KeyboardInterrupt (e.g. when user cancels cell execution)
# Can be removed once
# 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released
# 2. Dependencies updated for the toolkit. We don't depend on google-cloud-bigquery directly, but it's transitive
# dependency through sqlalchemy-bigquery
def _monkeypatch_bigquery_wait_or_cancel():
try:
import google.cloud.bigquery._job_helpers as _job_helpers
from google.cloud.bigquery import job, table

def _wait_or_cancel(
job_obj: job.QueryJob,
api_timeout: Optional[float],
wait_timeout: Optional[Union[object, float]],
retry: Optional[Any],
page_size: Optional[int],
max_results: Optional[int],
) -> table.RowIterator:
try:
return job_obj.result(
page_size=page_size,
max_results=max_results,
retry=retry,
timeout=wait_timeout,
)
except (KeyboardInterrupt, Exception):
try:
job_obj.cancel(retry=retry, timeout=api_timeout)
except (KeyboardInterrupt, Exception):
pass
raise

_job_helpers._wait_or_cancel = _wait_or_cancel
logger.debug(
"Successfully monkeypatched google.cloud.bigquery._job_helpers._wait_or_cancel"
)
except ImportError:
logger.warning(
"Could not monkeypatch BigQuery _wait_or_cancel: google.cloud.bigquery not available"
)
except Exception as e:
logger.warning("Failed to monkeypatch BigQuery _wait_or_cancel: %s", repr(e))


def apply_runtime_patches():
_monkeypatch_bigquery_wait_or_cancel()
8 changes: 8 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
import pytest


@pytest.fixture(autouse=True, scope="session")
def apply_runtime_patches() -> None:
"""Apply runtime patches once before any tests run."""
from deepnote_toolkit.runtime_patches import apply_runtime_patches

apply_runtime_patches()


@pytest.fixture(autouse=True)
def clean_runtime_state() -> Generator[None, None, None]:
"""Automatically clean in-memory env state and config cache before and after each test."""
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_sql_execution_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@
from deepnote_toolkit.sql import sql_execution as se


def test_bigquery_wait_or_cancel_handles_keyboard_interrupt():
import google.cloud.bigquery._job_helpers as _job_helpers

mock_job = mock.Mock()
mock_job.result.side_effect = KeyboardInterrupt("User interrupted")
mock_job.cancel = mock.Mock()

with pytest.raises(KeyboardInterrupt):
# _wait_or_cancel should be monkeypatched by `_monkeypatch_bigquery_wait_or_cancel`
_job_helpers._wait_or_cancel(
job_obj=mock_job,
api_timeout=30.0,
wait_timeout=60.0,
retry=None,
page_size=None,
max_results=None,
)

mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0)


def test_build_params_for_bigquery_oauth_ok():
with mock.patch(
"deepnote_toolkit.sql.sql_execution.bigquery.Client"
Expand Down