Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions sqlspec/builder/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def update(self, table_or_sql: str | None = None, dialect: DialectType = None) -

def delete(self, table_or_sql: str | None = None, dialect: DialectType = None) -> "Delete":
builder_dialect = dialect or self.dialect
builder = Delete(dialect=builder_dialect)
if table_or_sql and self._looks_like_sql(table_or_sql):
builder = Delete(dialect=builder_dialect)
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
if detected != "DELETE":
msg = (
Expand All @@ -345,23 +345,23 @@ def delete(self, table_or_sql: str | None = None, dialect: DialectType = None) -
)
raise SQLBuilderError(msg)
return self._populate_delete_from_sql(builder, table_or_sql)
return builder

return Delete(table_or_sql, dialect=builder_dialect) if table_or_sql else Delete(dialect=builder_dialect)

def merge(self, table_or_sql: str | None = None, dialect: DialectType = None) -> "Merge":
builder_dialect = dialect or self.dialect
builder = Merge(dialect=builder_dialect)
if table_or_sql:
if self._looks_like_sql(table_or_sql):
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
if detected != "MERGE":
msg = (
f"sql.merge() expects MERGE statement, got {detected}. "
f"Use sql.{detected.lower()}() if a dedicated builder exists."
)
raise SQLBuilderError(msg)
return self._populate_merge_from_sql(builder, table_or_sql)
return builder.into(table_or_sql)
return builder
if table_or_sql and self._looks_like_sql(table_or_sql):
builder = Merge(dialect=builder_dialect)
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
if detected != "MERGE":
msg = (
f"sql.merge() expects MERGE statement, got {detected}. "
f"Use sql.{detected.lower()}() if a dedicated builder exists."
)
raise SQLBuilderError(msg)
return self._populate_merge_from_sql(builder, table_or_sql)

return Merge(table_or_sql, dialect=builder_dialect) if table_or_sql else Merge(dialect=builder_dialect)

@property
def merge_(self) -> "Merge":
Expand Down
4 changes: 4 additions & 0 deletions sqlspec/driver/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,9 +1192,13 @@ def _create_count_query(self, original_sql: "SQL") -> "SQL":
Transforms the original SELECT statement to count total rows while preserving
WHERE, HAVING, and GROUP BY clauses but removing ORDER BY, LIMIT, and OFFSET.
"""
if not original_sql.expression:
original_sql.compile()

if not original_sql.expression:
msg = "Cannot create COUNT query from empty SQL expression"
raise ImproperConfigurationError(msg)

expr = original_sql.expression

if isinstance(expr, exp.Select):
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_builder/test_merge_factory_table_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Factory convenience tests for merge(table)."""

import pytest

from sqlspec.builder import sql
from sqlspec.exceptions import SQLBuilderError


def test_merge_factory_sets_target_table_from_positional_arg() -> None:
"""sql.merge(table) should set INTO target without separate into()."""

query = sql.merge("products").using("staging", alias="s").on("products.id = s.id").when_matched_then_delete()

stmt = query.build()

assert "products" in stmt.sql.lower()
assert "merge" in stmt.sql.lower()


def test_merge_factory_rejects_non_merge_sql() -> None:
"""sql.merge() with non-MERGE SQL should raise helpful error."""

bad_sql = "SELECT * FROM products"

with pytest.raises(SQLBuilderError):
sql.merge(bad_sql)
13 changes: 13 additions & 0 deletions tests/unit/test_builder/test_to_sql_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def test_to_sql_delete_statement() -> None:
assert "DELETE" in sql_str


def test_to_sql_delete_statement_with_table_arg() -> None:
"""Test sql.delete(table) sets target table without explicit from()."""

query = sql.delete("products").where("id = :id")
query.add_parameter(999, "id")

sql_str = query.to_sql(show_parameters=True)

assert "products" in sql_str.lower()
assert "delete" in sql_str.lower()
assert "999" in sql_str


def test_to_sql_same_parameter_name_multiple_times() -> None:
"""Test to_sql() handles same parameter name used multiple times."""
query = sql.select("*").from_("products").where("price >= :threshold").where("discount >= :threshold")
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/test_driver/test_create_count_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Tests for _create_count_query parsing behavior."""

from typing import Any, cast

from sqlspec import SQL
from sqlspec.core import get_default_config
from sqlspec.driver import CommonDriverAttributesMixin


class _CountDriver(CommonDriverAttributesMixin):
"""Minimal driver exposing _create_count_query for testing."""

__slots__ = ()

def __init__(self) -> None:
super().__init__(connection=None, statement_config=get_default_config())


def test_create_count_query_compiles_missing_expression() -> None:
"""Ensure count query generation parses SQL lacking prebuilt expression."""

driver = _CountDriver()
sql_statement = SQL("SELECT id FROM users WHERE active = true")

assert sql_statement.expression is None

count_sql = cast("Any", driver)._create_count_query(sql_statement)

assert sql_statement.expression is not None

compiled_sql, _ = count_sql.compile()

assert count_sql.expression is not None
assert "count" in compiled_sql.lower()
21 changes: 21 additions & 0 deletions tests/unit/test_loader/test_sql_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,27 @@ def test_parse_normalize_query_names() -> None:
assert "update_user_email" in statements


def test_get_sql_parses_expression_when_missing() -> None:
"""SQL objects from get_sql should carry parsed expressions for count queries."""

loader = SQLFileLoader()
content = """
-- name: list_users
SELECT id, email FROM user_account WHERE active = true;
"""

statements = SQLFileLoader._parse_sql_content(content, "test.sql")
loader._queries = statements

sql_obj = loader.get_sql("list_users")

assert sql_obj.expression is None

sql_obj.compile()

assert sql_obj.expression is not None


def test_parse_skips_files_without_named_statements() -> None:
"""Test that files without named statements return empty dict."""
content = "SELECT * FROM users;"
Expand Down
Loading