diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py
index 52378224..ef1e6066 100644
--- a/src/databricks/labs/lsql/backends.py
+++ b/src/databricks/labs/lsql/backends.py
@@ -21,7 +21,7 @@
 from databricks.sdk.retries import retried
 from databricks.sdk.service.compute import Language
 
-from databricks.labs.lsql.core import Row, StatementExecutionExt
+from databricks.labs.lsql.core import DeltaConcurrentAppend, Row, StatementExecutionExt
 from databricks.labs.lsql.structs import StructInference
 
 logger = logging.getLogger(__name__)
@@ -117,6 +117,8 @@ def _api_error_from_message(error_message: str) -> DatabricksError:
         return BadRequest(error_message)
     if "Operation not allowed" in error_message:
         return PermissionDenied(error_message)
+    if "DELTA_CONCURRENT_APPEND" in error_message:
+        return DeltaConcurrentAppend(error_message)
     return Unknown(error_message)
 
 
diff --git a/src/databricks/labs/lsql/core.py b/src/databricks/labs/lsql/core.py
index 856315b1..2ac80df6 100644
--- a/src/databricks/labs/lsql/core.py
+++ b/src/databricks/labs/lsql/core.py
@@ -13,7 +13,8 @@
 import requests
 import sqlglot
 from databricks.sdk import WorkspaceClient, errors
-from databricks.sdk.errors import DataLoss, NotFound
+from databricks.sdk.errors import BadRequest, DatabricksError, DataLoss, NotFound
+from databricks.sdk.retries import retried
 from databricks.sdk.service.sql import (
     ColumnInfoTypeName,
     Disposition,
@@ -119,6 +120,10 @@ def __repr__(self):
         return f"Row({', '.join(f'{k}={v!r}' for (k, v) in zip(self.__columns__, self, strict=True))})"
 
 
+class DeltaConcurrentAppend(DatabricksError):
+    """Raised when concurrent appends to a Delta table conflict with each other."""
+
+
 class StatementExecutionExt:
     """Execute SQL statements in a stateless manner.
 
@@ -182,6 +187,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments
             ColumnInfoTypeName.TIMESTAMP: self._parse_timestamp,
         }
 
+    @retried(on=[DeltaConcurrentAppend], timeout=timedelta(seconds=10))
     def execute(
         self,
         statement: str,
@@ -469,6 +475,8 @@ def _raise_if_needed(status: StatementStatus):
             raise NotFound(error_message)
         if "DELTA_MISSING_TRANSACTION_LOG" in error_message:
             raise DataLoss(error_message)
+        if "DELTA_CONCURRENT_APPEND" in error_message:
+            raise DeltaConcurrentAppend(error_message)
         mapping = {
             ServiceErrorCode.ABORTED: errors.Aborted,
             ServiceErrorCode.ALREADY_EXISTS: errors.AlreadyExists,
diff --git a/tests/integration/test_backends.py b/tests/integration/test_backends.py
index 6930aec2..ad853c8c 100644
--- a/tests/integration/test_backends.py
+++ b/tests/integration/test_backends.py
@@ -1,7 +1,10 @@
 import pytest
 from databricks.labs.blueprint.commands import CommandExecutor
 from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.parallel import Threads
 from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2
+from databricks.sdk.errors import BadRequest
+from databricks.sdk.service import compute
 
 from databricks.labs.lsql import Row
 from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend
@@ -74,7 +77,6 @@
     return "PASSED"
 """
 
-
 UNKNOWN_ERROR = """
 from databricks.labs.lsql.backends import RuntimeBackend
 from databricks.sdk.errors import Unknown
@@ -86,6 +88,37 @@
     return "PASSED"
 """
 
+CONCURRENT_APPEND = '''
+import math
+import time
+
+
+def wait_until_seconds_rollover(*, rollover_seconds: int = 10) -> None:
+    """Wait until the next rollover.
+
+    Useful for aligning concurrent writes.
+
+    Args:
+        rollover_seconds (int): Wait until the next multiple of this number of seconds.
+    """
+    nano, milli = 1e9, 1e6  # nanoseconds per second and per millisecond
+
+    # Round the boot-clock time up to the next multiple of the rollover interval
+    nanoseconds_now = time.clock_gettime_ns(time.CLOCK_BOOTTIME)
+    nanoseconds_target = math.ceil(nanoseconds_now / (nano * rollover_seconds)) * nano * rollover_seconds
+
+    # To hit the rollover more accurately, first sleep until just before the target
+    nanoseconds_until_almost_target = (nanoseconds_target - nanoseconds_now) - milli
+    time.sleep(max(nanoseconds_until_almost_target / nano, 0))
+
+    # Then busy-wait until the rollover occurs
+    while time.clock_gettime_ns(time.CLOCK_BOOTTIME) < nanoseconds_target:
+        pass
+
+
+wait_until_seconds_rollover()
+spark.sql("UPDATE {table_full_name} SET y = y * 2 WHERE (x % 2 = 0)")
+'''
+
 
 @pytest.mark.xfail
 def test_runtime_backend_works_maps_permission_denied(ws):
@@ -139,6 +172,27 @@ def test_runtime_backend_errors_handled(ws, query):
     assert result == "PASSED"
 
 
+def test_runtime_backend_handles_concurrent_append(ws, make_random, make_table) -> None:
+    commands = CommandExecutor(
+        ws.clusters,
+        ws.command_execution,
+        lambda: ws.config.cluster_id,
+        language=compute.Language.PYTHON,
+    )
+    table = make_table(name=f"lsql_test_{make_random()}", ctas="SELECT r.id AS x, random() AS y FROM range(1000000) r")
+
+    def update_table() -> None:
+        commands.run(CONCURRENT_APPEND.format(table_full_name=table.full_name))
+
+    try:
+        Threads.strict("concurrent appends", [update_table, update_table])
+    except BadRequest as e:
+        if "DELTA_CONCURRENT_APPEND" in str(e):
+            assert False, str(e)
+        else:
+            raise  # Re-raise unexpected errors
+
+
 def test_statement_execution_backend_works(ws, env_or_skip):
     sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
     rows = list(sql_backend.fetch("SELECT * FROM samples.nyctaxi.trips LIMIT 10"))
diff --git a/tests/integration/test_core.py b/tests/integration/test_core.py
index 375a03b7..865d9364 100644
--- a/tests/integration/test_core.py
+++ b/tests/integration/test_core.py
@@ -1,6 +1,8 @@
 import logging
 
 import pytest
+from databricks.labs.blueprint.parallel import Threads
+from databricks.sdk.errors import BadRequest
 from databricks.sdk.service.sql import Disposition
 
 from databricks.labs.lsql.core import Row, StatementExecutionExt
@@ -83,3 +85,18 @@ def test_fetch_value(ws):
     see = StatementExecutionExt(ws)
     count = see.fetch_value("SELECT COUNT(*) FROM samples.nyctaxi.trips")
     assert count == 21932
+
+
+def test_statement_execution_backend_handles_concurrent_append(sql_backend, make_random, make_table) -> None:
+    table = make_table(name=f"lsql_test_{make_random()}", ctas="SELECT r.id AS x, random() AS y FROM range(1000000) r")
+
+    def update_table() -> None:
+        sql_backend.execute(f"UPDATE {table.full_name} SET y = y * 2 WHERE (x % 2 = 0)")
+
+    try:
+        Threads.strict("concurrent appends", [update_table, update_table])
+    except BadRequest as e:
+        if "DELTA_CONCURRENT_APPEND" in str(e):
+            assert False, str(e)
+        else:
+            raise  # Re-raise unexpected errors
diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py
index 71db5207..04700be0 100644
--- a/tests/unit/test_backends.py
+++ b/tests/unit/test_backends.py
@@ -1,5 +1,6 @@
 import datetime
 import os
+import re
 import sys
 from dataclasses import dataclass
 from unittest import mock
@@ -32,6 +33,7 @@
     RuntimeBackend,
     StatementExecutionBackend,
 )
+from databricks.labs.lsql.core import DeltaConcurrentAppend
 
 # pylint: disable=protected-access
 
@@ -364,9 +366,10 @@ def test_save_table_with_not_null_constraint_violated():
("PARSE_SYNTAX_ERROR foo", BadRequest), ("foo Operation not allowed", PermissionDenied), ("foo error failure", Unknown), + ("[DELTA_CONCURRENT_APPEND] ConcurrentAppendException: Files were added ...", DeltaConcurrentAppend), ], ) -def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t): +def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t) -> None: with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): pyspark_sql_session = MagicMock() sys.modules["pyspark.sql.session"] = pyspark_sql_session @@ -376,10 +379,10 @@ def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t runtime_backend = RuntimeBackend() - with pytest.raises(err_t): + with pytest.raises(err_t, match=re.escape(msg)): runtime_backend.execute("SELECT * from bar") - with pytest.raises(err_t): + with pytest.raises(err_t, match=re.escape(msg)): list(runtime_backend.fetch("SELECT * from bar")) diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index 18e93549..a11eb593 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -22,7 +22,7 @@ timedelta, ) -from databricks.labs.lsql.core import Row, StatementExecutionExt +from databricks.labs.lsql.core import DeltaConcurrentAppend, Row, StatementExecutionExt @pytest.mark.parametrize( @@ -196,19 +196,18 @@ def test_execute_poll_succeeds(): (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), (ServiceError(message="... DELTA_MISSING_TRANSACTION_LOG ..."), errors.DataLoss), + (ServiceError(message="... DELTA_CONCURRENT_APPEND ..."), DeltaConcurrentAppend), ], ) -def test_execute_fails(status_error, platform_error_type): +def test_execute_fails(status_error, platform_error_type) -> None: ws = create_autospec(WorkspaceClient) - ws.statement_execution.execute_statement.return_value = StatementResponse( status=StatementStatus(state=StatementState.FAILED, error=status_error), statement_id="bcd", ) - see = StatementExecutionExt(ws, warehouse_id="abc") - with pytest.raises(platform_error_type): + with pytest.raises(platform_error_type, match=status_error.message if status_error is not None else None): see.execute("SELECT 2+2")