Skip to content

Commit 64476a2

Browse files
authored
fix(builder): correctly handle edge cases (#284)
Enhance SQL builder methods to correctly handle edge cases for delete and merge operations. Add tests to ensure proper functionality and error handling for invalid SQL inputs. Improve count query generation to compile missing expressions.
1 parent 92ebb0e commit 64476a2

File tree

7 files changed

+334
-105
lines changed

7 files changed

+334
-105
lines changed

sqlspec/builder/_factory.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def update(self, table_or_sql: str | None = None, dialect: DialectType = None) -
335335

336336
def delete(self, table_or_sql: str | None = None, dialect: DialectType = None) -> "Delete":
337337
builder_dialect = dialect or self.dialect
338-
builder = Delete(dialect=builder_dialect)
339338
if table_or_sql and self._looks_like_sql(table_or_sql):
339+
builder = Delete(dialect=builder_dialect)
340340
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
341341
if detected != "DELETE":
342342
msg = (
@@ -345,23 +345,23 @@ def delete(self, table_or_sql: str | None = None, dialect: DialectType = None) -
345345
)
346346
raise SQLBuilderError(msg)
347347
return self._populate_delete_from_sql(builder, table_or_sql)
348-
return builder
348+
349+
return Delete(table_or_sql, dialect=builder_dialect) if table_or_sql else Delete(dialect=builder_dialect)
349350

350351
def merge(self, table_or_sql: str | None = None, dialect: DialectType = None) -> "Merge":
351352
builder_dialect = dialect or self.dialect
352-
builder = Merge(dialect=builder_dialect)
353-
if table_or_sql:
354-
if self._looks_like_sql(table_or_sql):
355-
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
356-
if detected != "MERGE":
357-
msg = (
358-
f"sql.merge() expects MERGE statement, got {detected}. "
359-
f"Use sql.{detected.lower()}() if a dedicated builder exists."
360-
)
361-
raise SQLBuilderError(msg)
362-
return self._populate_merge_from_sql(builder, table_or_sql)
363-
return builder.into(table_or_sql)
364-
return builder
353+
if table_or_sql and self._looks_like_sql(table_or_sql):
354+
builder = Merge(dialect=builder_dialect)
355+
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
356+
if detected != "MERGE":
357+
msg = (
358+
f"sql.merge() expects MERGE statement, got {detected}. "
359+
f"Use sql.{detected.lower()}() if a dedicated builder exists."
360+
)
361+
raise SQLBuilderError(msg)
362+
return self._populate_merge_from_sql(builder, table_or_sql)
363+
364+
return Merge(table_or_sql, dialect=builder_dialect) if table_or_sql else Merge(dialect=builder_dialect)
365365

366366
@property
367367
def merge_(self) -> "Merge":

sqlspec/driver/_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,9 +1192,13 @@ def _create_count_query(self, original_sql: "SQL") -> "SQL":
11921192
Transforms the original SELECT statement to count total rows while preserving
11931193
WHERE, HAVING, and GROUP BY clauses but removing ORDER BY, LIMIT, and OFFSET.
11941194
"""
1195+
if not original_sql.expression:
1196+
original_sql.compile()
1197+
11951198
if not original_sql.expression:
11961199
msg = "Cannot create COUNT query from empty SQL expression"
11971200
raise ImproperConfigurationError(msg)
1201+
11981202
expr = original_sql.expression
11991203

12001204
if isinstance(expr, exp.Select):
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Factory convenience tests for merge(table)."""
2+
3+
import pytest
4+
5+
from sqlspec.builder import sql
6+
from sqlspec.exceptions import SQLBuilderError
7+
8+
9+
def test_merge_factory_sets_target_table_from_positional_arg() -> None:
10+
"""sql.merge(table) should set INTO target without separate into()."""
11+
12+
query = sql.merge("products").using("staging", alias="s").on("products.id = s.id").when_matched_then_delete()
13+
14+
stmt = query.build()
15+
16+
assert "products" in stmt.sql.lower()
17+
assert "merge" in stmt.sql.lower()
18+
19+
20+
def test_merge_factory_rejects_non_merge_sql() -> None:
21+
"""sql.merge() with non-MERGE SQL should raise helpful error."""
22+
23+
bad_sql = "SELECT * FROM products"
24+
25+
with pytest.raises(SQLBuilderError):
26+
sql.merge(bad_sql)

tests/unit/test_builder/test_to_sql_edge_cases.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,19 @@ def test_to_sql_delete_statement() -> None:
160160
assert "DELETE" in sql_str
161161

162162

163+
def test_to_sql_delete_statement_with_table_arg() -> None:
164+
"""Test sql.delete(table) sets target table without explicit from()."""
165+
166+
query = sql.delete("products").where("id = :id")
167+
query.add_parameter(999, "id")
168+
169+
sql_str = query.to_sql(show_parameters=True)
170+
171+
assert "products" in sql_str.lower()
172+
assert "delete" in sql_str.lower()
173+
assert "999" in sql_str
174+
175+
163176
def test_to_sql_same_parameter_name_multiple_times() -> None:
164177
"""Test to_sql() handles same parameter name used multiple times."""
165178
query = sql.select("*").from_("products").where("price >= :threshold").where("discount >= :threshold")
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Tests for _create_count_query parsing behavior."""
2+
3+
from typing import Any, cast
4+
5+
from sqlspec import SQL
6+
from sqlspec.core import get_default_config
7+
from sqlspec.driver import CommonDriverAttributesMixin
8+
9+
10+
class _CountDriver(CommonDriverAttributesMixin):
11+
"""Minimal driver exposing _create_count_query for testing."""
12+
13+
__slots__ = ()
14+
15+
def __init__(self) -> None:
16+
super().__init__(connection=None, statement_config=get_default_config())
17+
18+
19+
def test_create_count_query_compiles_missing_expression() -> None:
20+
"""Ensure count query generation parses SQL lacking prebuilt expression."""
21+
22+
driver = _CountDriver()
23+
sql_statement = SQL("SELECT id FROM users WHERE active = true")
24+
25+
assert sql_statement.expression is None
26+
27+
count_sql = cast("Any", driver)._create_count_query(sql_statement)
28+
29+
assert sql_statement.expression is not None
30+
31+
compiled_sql, _ = count_sql.compile()
32+
33+
assert count_sql.expression is not None
34+
assert "count" in compiled_sql.lower()

tests/unit/test_loader/test_sql_file_loader.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,27 @@ def test_parse_normalize_query_names() -> None:
205205
assert "update_user_email" in statements
206206

207207

208+
def test_get_sql_parses_expression_when_missing() -> None:
209+
"""SQL objects from get_sql should carry parsed expressions for count queries."""
210+
211+
loader = SQLFileLoader()
212+
content = """
213+
-- name: list_users
214+
SELECT id, email FROM user_account WHERE active = true;
215+
"""
216+
217+
statements = SQLFileLoader._parse_sql_content(content, "test.sql")
218+
loader._queries = statements
219+
220+
sql_obj = loader.get_sql("list_users")
221+
222+
assert sql_obj.expression is None
223+
224+
sql_obj.compile()
225+
226+
assert sql_obj.expression is not None
227+
228+
208229
def test_parse_skips_files_without_named_statements() -> None:
209230
"""Test that files without named statements return empty dict."""
210231
content = "SELECT * FROM users;"

0 commit comments

Comments
 (0)