Skip to content

Commit a2e3e0b

Browse files
authored
fix: returns_row return False incorrectly (#263)
Corrects an issue where operation type incorrectly report a result would not return rows under certain conditions.
1 parent 93fbb4b commit a2e3e0b

File tree

9 files changed

+151
-7
lines changed

9 files changed

+151
-7
lines changed

sqlspec/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ParameterProcessor,
2929
ParameterStyle,
3030
ParameterStyleConfig,
31+
ProcessedState,
3132
SQLResult,
3233
StackOperation,
3334
StackResult,
@@ -66,6 +67,7 @@
6667
"ParameterStyle",
6768
"ParameterStyleConfig",
6869
"PoolT",
70+
"ProcessedState",
6971
"QueryBuilder",
7072
"SQLFactory",
7173
"SQLFile",

sqlspec/adapters/adbc/driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,9 @@ def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResu
513513
self._handle_postgres_rollback(cursor)
514514
raise
515515

516-
if statement.returns_rows():
516+
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)
517+
518+
if is_select_like:
517519
fetched_data = cursor.fetchall()
518520
column_names = [col[0] for col in cursor.description or []]
519521

sqlspec/adapters/bigquery/driver.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,9 +686,13 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> ExecutionResult:
686686
"""
687687
sql, parameters = self._get_compiled_sql(statement, self.statement_config)
688688
cursor.job = self._run_query_job(sql, parameters, connection=cursor)
689+
job_result = cursor.job.result(job_retry=self._job_retry)
690+
statement_type = str(cursor.job.statement_type or "").upper()
691+
is_select_like = (
692+
statement.returns_rows() or statement_type == "SELECT" or self._should_force_select(statement, cursor)
693+
)
689694

690-
if statement.returns_rows():
691-
job_result = cursor.job.result(job_retry=self._job_retry)
695+
if is_select_like:
692696
rows_list = self._rows_to_results(iter(job_result))
693697
column_names = [field.name for field in cursor.job.schema] if cursor.job.schema else []
694698

@@ -700,7 +704,6 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> ExecutionResult:
700704
is_select_result=True,
701705
)
702706

703-
cursor.job.result(job_retry=self._job_retry)
704707
affected_rows = cursor.job.num_dml_affected_rows or 0
705708
return self.create_execution_result(cursor, rowcount_override=affected_rows)
706709

sqlspec/adapters/duckdb/driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,9 @@ def _execute_statement(self, cursor: Any, statement: SQL) -> "ExecutionResult":
349349
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
350350
cursor.execute(sql, prepared_parameters or ())
351351

352-
if statement.returns_rows():
352+
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)
353+
354+
if is_select_like:
353355
fetched_data = cursor.fetchall()
354356
column_names = [col[0] for col in cursor.description or []]
355357

sqlspec/adapters/oracledb/driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,9 @@ async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionR
12431243
await cursor.execute(sql, prepared_parameters or {})
12441244

12451245
# SELECT result processing for Oracle
1246-
if statement.returns_rows():
1246+
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)
1247+
1248+
if is_select_like:
12471249
fetched_data = await cursor.fetchall()
12481250
column_names = [col[0] for col in cursor.description or []]
12491251
column_names = _normalize_column_names(column_names, self.driver_features)

sqlspec/core/compiler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@
5555

5656
OPERATION_TYPE_MAP: "dict[type[exp.Expression], OperationType]" = {
5757
exp.Select: "SELECT",
58+
exp.Union: "SELECT",
59+
exp.Except: "SELECT",
60+
exp.Intersect: "SELECT",
61+
exp.With: "SELECT",
5862
exp.Insert: "INSERT",
5963
exp.Update: "UPDATE",
6064
exp.Delete: "DELETE",
@@ -554,7 +558,9 @@ def _build_operation_profile(
554558
modifies_rows = False
555559

556560
expr = expression
557-
if isinstance(expr, (exp.Select, exp.Values, exp.Table, exp.TableSample, exp.With)):
561+
if isinstance(
562+
expr, (exp.Select, exp.Union, exp.Except, exp.Intersect, exp.Values, exp.Table, exp.TableSample, exp.With)
563+
):
558564
returns_rows = True
559565
elif isinstance(expr, (exp.Insert, exp.Update, exp.Delete, exp.Merge)):
560566
modifies_rows = True

sqlspec/driver/_common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,32 @@ def build_statement_result(self, statement: "SQL", execution_result: ExecutionRe
548548
metadata=execution_result.special_data or {"status_message": "OK"},
549549
)
550550

551+
def _should_force_select(self, statement: "SQL", cursor: Any) -> bool:
552+
"""Determine if a statement with unknown type should be treated as SELECT.
553+
554+
Uses driver metadata (statement_type, description/schema) as a safety net when
555+
the compiler cannot classify the operation. This remains conservative by only
556+
triggering when the operation type is "UNKNOWN".
557+
558+
Args:
559+
statement: SQL statement being executed.
560+
cursor: Database cursor/job object that may expose metadata.
561+
562+
Returns:
563+
True when cursor metadata indicates a row-returning operation despite an
564+
unknown operation type; otherwise False.
565+
"""
566+
567+
if statement.operation_type != "UNKNOWN":
568+
return False
569+
570+
statement_type = getattr(cursor, "statement_type", None)
571+
if isinstance(statement_type, str) and statement_type.upper() == "SELECT":
572+
return True
573+
574+
description = getattr(cursor, "description", None)
575+
return bool(description)
576+
551577
def prepare_statement(
552578
self,
553579
statement: "Statement | QueryBuilder",

tests/unit/test_core/test_statement.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,27 @@ def test_sql_returns_rows_detection() -> None:
534534
assert show_stmt.returns_rows() is True
535535

536536

537+
@pytest.mark.parametrize(
538+
"sql_text",
539+
[
540+
"SELECT 1 UNION ALL SELECT 2",
541+
"SELECT 1 EXCEPT SELECT 2",
542+
"SELECT 1 INTERSECT SELECT 1",
543+
"WITH cte AS (SELECT 1 AS id) SELECT * FROM cte",
544+
],
545+
ids=["union", "except", "intersect", "cte_select"],
546+
)
547+
def test_sql_set_and_cte_operations_detect_as_select(sql_text: str) -> None:
548+
"""Ensure set operations and CTE queries are detected as SELECT and return rows."""
549+
550+
stmt = SQL(sql_text)
551+
stmt.compile()
552+
553+
assert stmt.operation_type == "SELECT"
554+
assert stmt.returns_rows() is True
555+
assert stmt.is_modifying_operation() is False
556+
557+
537558
def test_sql_slots_prevent_new_attributes() -> None:
538559
"""Test SQL __slots__ prevent adding new attributes."""
539560
stmt = SQL("SELECT * FROM users")
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Tests for the _should_force_select safety net."""
2+
3+
from typing import Any, cast
4+
5+
from sqlspec import SQL, ProcessedState
6+
from sqlspec.adapters.bigquery import bigquery_statement_config
7+
from sqlspec.driver import CommonDriverAttributesMixin
8+
9+
10+
class _DummyDriver(CommonDriverAttributesMixin):
11+
"""Minimal driver to expose _should_force_select for testing."""
12+
13+
__slots__ = ()
14+
15+
def __init__(self) -> None:
16+
super().__init__(connection=None, statement_config=bigquery_statement_config)
17+
18+
19+
class _CursorWithStatementType:
20+
"""Cursor exposing a statement_type attribute."""
21+
22+
def __init__(self, statement_type: str | None) -> None:
23+
self.statement_type = statement_type
24+
self.description = None
25+
26+
27+
class _CursorWithDescription:
28+
"""Cursor exposing a description attribute."""
29+
30+
def __init__(self, has_description: bool) -> None:
31+
self.description = [("col",)] if has_description else None
32+
self.statement_type = None
33+
34+
35+
def _make_unknown_statement(sql_text: str = "select 1") -> "SQL":
36+
stmt = SQL(sql_text)
37+
cast("Any", stmt)._processed_state = ProcessedState(
38+
compiled_sql=sql_text, execution_parameters={}, operation_type="UNKNOWN"
39+
)
40+
return stmt
41+
42+
43+
def _make_select_statement(sql_text: str = "select 1") -> "SQL":
44+
stmt = SQL(sql_text)
45+
cast("Any", stmt)._processed_state = ProcessedState(
46+
compiled_sql=sql_text, execution_parameters={}, operation_type="SELECT"
47+
)
48+
return stmt
49+
50+
51+
def test_force_select_uses_statement_type_select() -> None:
52+
driver = _DummyDriver()
53+
stmt = _make_unknown_statement()
54+
cursor = _CursorWithStatementType("SELECT")
55+
56+
assert cast("Any", driver)._should_force_select(stmt, cursor) is True
57+
58+
59+
def test_force_select_uses_description_when_unknown() -> None:
60+
driver = _DummyDriver()
61+
stmt = _make_unknown_statement()
62+
cursor = _CursorWithDescription(True)
63+
64+
assert cast("Any", driver)._should_force_select(stmt, cursor) is True
65+
66+
67+
def test_force_select_false_when_no_metadata() -> None:
68+
driver = _DummyDriver()
69+
stmt = _make_unknown_statement()
70+
cursor = _CursorWithDescription(False)
71+
72+
assert cast("Any", driver)._should_force_select(stmt, cursor) is False
73+
74+
75+
def test_force_select_ignored_when_operation_known() -> None:
76+
driver = _DummyDriver()
77+
stmt = _make_select_statement()
78+
cursor = _CursorWithDescription(True)
79+
80+
assert cast("Any", driver)._should_force_select(stmt, cursor) is False

0 commit comments

Comments
 (0)