Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/examples/adbc_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def adbc_example() -> None:
# Create SQLSpec instance with ADBC (connects to dev PostgreSQL container)
spec = SQLSpec()
config = AdbcConfig(connection_config={"uri": "postgresql://postgres:postgres@localhost:5433/postgres"})
spec.add_config(config)
db = spec.add_config(config)

# Get a driver directly (drivers now have built-in query methods)
with spec.provide_session(config) as driver:
with spec.provide_session(db) as driver:
# Create a table
driver.execute("""
CREATE TABLE IF NOT EXISTS analytics_data (
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
name = "sqlspec"
readme = "README.md"
requires-python = ">=3.9, <4.0"
version = "0.17.0"
version = "0.17.1"

[project.urls]
Discord = "https://discord.gg/litestar"
Expand Down
215 changes: 137 additions & 78 deletions sqlspec/_sql.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion sqlspec/builder/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from collections.abc import Iterable
from typing import Any, Optional
from typing import Any, Optional, cast

from sqlglot import exp

Expand Down Expand Up @@ -241,6 +241,10 @@ def desc(self) -> exp.Ordered:
"""Create a DESC ordering expression."""
return exp.Ordered(this=self._expression, desc=True)

def as_(self, alias: str) -> exp.Alias:
"""Create an aliased expression."""
return cast("exp.Alias", exp.alias_(self._expression, alias))

def __repr__(self) -> str:
if self.table:
return f"Column<{self.table}.{self.name}>"
Expand Down
46 changes: 46 additions & 0 deletions sqlspec/builder/_expression_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Expression wrapper classes for proper type annotations."""

from typing import cast

from sqlglot import exp

__all__ = ("AggregateExpression", "ConversionExpression", "FunctionExpression", "MathExpression", "StringExpression")


class ExpressionWrapper:
"""Base wrapper for SQLGlot expressions."""

def __init__(self, expression: exp.Expression) -> None:
self._expression = expression

def as_(self, alias: str) -> exp.Alias:
"""Create an aliased expression."""
return cast("exp.Alias", exp.alias_(self._expression, alias))

@property
def expression(self) -> exp.Expression:
"""Get the underlying SQLGlot expression."""
return self._expression

def __str__(self) -> str:
return str(self._expression)


class AggregateExpression(ExpressionWrapper):
"""Aggregate functions like COUNT, SUM, AVG."""


class FunctionExpression(ExpressionWrapper):
"""General SQL functions."""


class MathExpression(ExpressionWrapper):
"""Mathematical functions like ROUND."""


class StringExpression(ExpressionWrapper):
"""String functions like UPPER, LOWER, LENGTH."""


class ConversionExpression(ExpressionWrapper):
"""Conversion functions like CAST, COALESCE."""
4 changes: 1 addition & 3 deletions sqlspec/builder/_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,7 @@ def do_update(self, **kwargs: Any) -> "Insert":
# Create ON CONFLICT with proper structure
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
on_conflict = exp.OnConflict(
conflict_keys=conflict_keys,
action=exp.var("DO UPDATE"),
expressions=set_expressions if set_expressions else None,
conflict_keys=conflict_keys, action=exp.var("DO UPDATE"), expressions=set_expressions or None
)

insert_expr.set("conflict", on_conflict)
Expand Down
10 changes: 5 additions & 5 deletions sqlspec/builder/_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@ class Update(
update_query = (
Update()
.table("users")
.set(name="John Doe")
.set(email="[email protected]")
.set_(name="John Doe")
.set_(email="[email protected]")
.where("id = 1")
)

update_query = (
Update("users").set(name="John Doe").where("id = 1")
Update("users").set_(name="John Doe").where("id = 1")
)

update_query = (
Update()
.table("users")
.set(status="active")
.set_(status="active")
.where_eq("id", 123)
)

update_query = (
Update()
.table("users", "u")
.set(name="Updated Name")
.set_(name="Updated Name")
.from_("profiles", "p")
.where("u.id = p.user_id AND p.is_verified = true")
)
Expand Down
20 changes: 16 additions & 4 deletions sqlspec/builder/mixins/_order_limit_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from sqlspec.exceptions import SQLBuilderError

if TYPE_CHECKING:
from sqlspec.builder._column import Column
from sqlspec.builder._expression_wrappers import ExpressionWrapper
from sqlspec.builder.mixins._select_operations import Case
from sqlspec.protocols import SQLBuilderProtocol

__all__ = ("LimitOffsetClauseMixin", "OrderByClauseMixin", "ReturningClauseMixin")
Expand All @@ -24,7 +27,7 @@ class OrderByClauseMixin:
# Type annotation for PyRight - this will be provided by the base class
_expression: Optional[exp.Expression]

def order_by(self, *items: Union[str, exp.Ordered], desc: bool = False) -> Self:
def order_by(self, *items: Union[str, exp.Ordered, "Column"], desc: bool = False) -> Self:
"""Add ORDER BY clause.

Args:
Expand All @@ -49,7 +52,13 @@ def order_by(self, *items: Union[str, exp.Ordered], desc: bool = False) -> Self:
if desc:
order_item = order_item.desc()
else:
order_item = item
# Extract expression from Column objects or use as-is for sqlglot expressions
from sqlspec._sql import SQLFactory

extracted_item = SQLFactory._extract_expression(item)
order_item = extracted_item
if desc and not isinstance(item, exp.Ordered):
order_item = order_item.desc()
current_expr = current_expr.order_by(order_item, copy=False)
builder._expression = current_expr
return cast("Self", builder)
Expand Down Expand Up @@ -111,7 +120,7 @@ class ReturningClauseMixin:
# Type annotation for PyRight - this will be provided by the base class
_expression: Optional[exp.Expression]

def returning(self, *columns: Union[str, exp.Expression]) -> Self:
def returning(self, *columns: Union[str, exp.Expression, "Column", "ExpressionWrapper", "Case"]) -> Self:
"""Add RETURNING clause to the statement.

Args:
Expand All @@ -130,6 +139,9 @@ def returning(self, *columns: Union[str, exp.Expression]) -> Self:
if not isinstance(self._expression, valid_types):
msg = "RETURNING is only supported for INSERT, UPDATE, and DELETE statements."
raise SQLBuilderError(msg)
returning_exprs = [exp.column(c) if isinstance(c, str) else c for c in columns]
# Extract expressions from various wrapper types
from sqlspec._sql import SQLFactory

returning_exprs = [SQLFactory._extract_expression(c) for c in columns]
self._expression.set("returning", exp.Returning(expressions=returning_exprs))
return self
4 changes: 2 additions & 2 deletions sqlspec/builder/mixins/_select_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def when(self, condition: Union[str, exp.Expression], value: Union[str, exp.Expr
from sqlspec._sql import SQLFactory

cond_expr = exp.maybe_parse(condition) or exp.column(condition) if isinstance(condition, str) else condition
val_expr = SQLFactory._to_literal(value)
val_expr = SQLFactory._to_expression(value)

# SQLGlot uses exp.If for CASE WHEN clauses, not exp.When
when_clause = exp.If(this=cond_expr, true=val_expr)
Expand All @@ -876,7 +876,7 @@ def else_(self, value: Union[str, exp.Expression, Any]) -> Self:
"""
from sqlspec._sql import SQLFactory

self._default = SQLFactory._to_literal(value)
self._default = SQLFactory._to_expression(value)
return self

def end(self) -> Self:
Expand Down
8 changes: 4 additions & 4 deletions sqlspec/builder/mixins/_update_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def set(self, *args: Any, **kwargs: Any) -> Self:
"""Set columns and values for the UPDATE statement.

Supports:
- set(column, value)
- set(mapping)
- set(**kwargs)
- set(mapping, **kwargs)
- set_(column, value)
- set_(mapping)
- set_(**kwargs)
- set_(mapping, **kwargs)

Args:
*args: Either (column, value) or a mapping.
Expand Down
10 changes: 10 additions & 0 deletions sqlspec/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
__all__ = (
"BytesConvertibleProtocol",
"DictProtocol",
"ExpressionWithAliasProtocol",
"FilterAppenderProtocol",
"FilterParameterProtocol",
"HasExpressionProtocol",
Expand Down Expand Up @@ -172,6 +173,15 @@ def __bytes__(self) -> bytes:
...


@runtime_checkable
class ExpressionWithAliasProtocol(Protocol):
"""Protocol for SQL expressions that support aliasing with as_() method."""

def as_(self, alias: str, **kwargs: Any) -> "exp.Alias":
"""Create an aliased expression."""
...


@runtime_checkable
class ObjectStoreItemProtocol(Protocol):
"""Protocol for object store items with path/key attributes."""
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_builder/test_parameter_naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
- Edge cases and error conditions
"""

import string

from sqlspec import sql


Expand Down Expand Up @@ -329,7 +331,7 @@ def test_parameter_names_are_sql_safe() -> None:
assert "--" not in param_name

# Should be valid identifier-like
assert param_name.replace("_", "").replace("0123456789", "").isalpha() or "_" in param_name
assert param_name.replace("_", "").replace(string.digits, "").isalpha() or "_" in param_name


def test_empty_and_null_values_preserve_column_names() -> None:
Expand Down
44 changes: 36 additions & 8 deletions tests/unit/test_sql_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,35 +416,63 @@ def test_all_ddl_methods_exist() -> None:

def test_count_function() -> None:
"""Test sql.count() function."""
from sqlspec.builder._expression_wrappers import AggregateExpression

expr = sql.count()
assert isinstance(expr, exp.Expression)
assert isinstance(expr, AggregateExpression)
assert hasattr(expr, "as_")
assert hasattr(expr, "expression")
assert isinstance(expr.expression, exp.Expression)

count_column = sql.count("user_id")
assert isinstance(count_column, exp.Expression)
assert isinstance(count_column, AggregateExpression)
assert hasattr(count_column, "as_")
assert hasattr(count_column, "expression")
assert isinstance(count_column.expression, exp.Expression)


def test_sum_function() -> None:
"""Test sql.sum() function."""
from sqlspec.builder._expression_wrappers import AggregateExpression

expr = sql.sum("amount")
assert isinstance(expr, exp.Expression)
assert isinstance(expr, AggregateExpression)
assert hasattr(expr, "as_")
assert hasattr(expr, "expression")
assert isinstance(expr.expression, exp.Expression)


def test_avg_function() -> None:
"""Test sql.avg() function."""
from sqlspec.builder._expression_wrappers import AggregateExpression

expr = sql.avg("score")
assert isinstance(expr, exp.Expression)
assert isinstance(expr, AggregateExpression)
assert hasattr(expr, "as_")
assert hasattr(expr, "expression")
assert isinstance(expr.expression, exp.Expression)


def test_max_function() -> None:
"""Test sql.max() function."""
from sqlspec.builder._expression_wrappers import AggregateExpression

expr = sql.max("created_at")
assert isinstance(expr, exp.Expression)
assert isinstance(expr, AggregateExpression)
assert hasattr(expr, "as_")
assert hasattr(expr, "expression")
assert isinstance(expr.expression, exp.Expression)


def test_min_function() -> None:
"""Test sql.min() function."""
from sqlspec.builder._expression_wrappers import AggregateExpression

expr = sql.min("price")
assert isinstance(expr, exp.Expression)
assert isinstance(expr, AggregateExpression)
assert hasattr(expr, "as_")
assert hasattr(expr, "expression")
assert isinstance(expr.expression, exp.Expression)


def test_column_method() -> None:
Expand Down Expand Up @@ -1074,7 +1102,7 @@ def test_type_compatibility_across_all_operations() -> None:


def test_update_set_method_with_sql_objects() -> None:
"""Test that UPDATE.set() method properly handles SQL objects with kwargs."""
"""Test that UPDATE.set_() method properly handles SQL objects with kwargs."""
raw_timestamp = sql.raw("NOW()")
raw_computed = sql.raw("UPPER(:value)", value="test")

Expand All @@ -1097,7 +1125,7 @@ def test_update_set_method_with_sql_objects() -> None:


def test_update_set_method_backward_compatibility() -> None:
"""Test that UPDATE.set() method maintains backward compatibility with dict."""
"""Test that UPDATE.set_() method maintains backward compatibility with dict."""
raw_timestamp = sql.raw("NOW()")

# Test using dict (original API)
Expand Down
Loading