Skip to content

Commit 2ade6a1

Browse files
authored
fix: improve column and returning type hints (#56)
Additional type hints for columns and returning statements in the builder.
1 parent 0494fa3 commit 2ade6a1

File tree

14 files changed

+369
-210
lines changed

14 files changed

+369
-210
lines changed

docs/examples/adbc_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def adbc_example() -> None:
1616
# Create SQLSpec instance with ADBC (connects to dev PostgreSQL container)
1717
spec = SQLSpec()
1818
config = AdbcConfig(connection_config={"uri": "postgresql://postgres:postgres@localhost:5433/postgres"})
19-
spec.add_config(config)
19+
db = spec.add_config(config)
2020

2121
# Get a driver directly (drivers now have built-in query methods)
22-
with spec.provide_session(config) as driver:
22+
with spec.provide_session(db) as driver:
2323
# Create a table
2424
driver.execute("""
2525
CREATE TABLE IF NOT EXISTS analytics_data (

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
1313
name = "sqlspec"
1414
readme = "README.md"
1515
requires-python = ">=3.9, <4.0"
16-
version = "0.17.0"
16+
version = "0.17.1"
1717

1818
[project.urls]
1919
Discord = "https://discord.gg/litestar"

sqlspec/_sql.py

Lines changed: 137 additions & 78 deletions
Large diffs are not rendered by default.

sqlspec/builder/_column.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from collections.abc import Iterable
8-
from typing import Any, Optional
8+
from typing import Any, Optional, cast
99

1010
from sqlglot import exp
1111

@@ -241,6 +241,10 @@ def desc(self) -> exp.Ordered:
241241
"""Create a DESC ordering expression."""
242242
return exp.Ordered(this=self._expression, desc=True)
243243

244+
def as_(self, alias: str) -> exp.Alias:
245+
"""Create an aliased expression."""
246+
return cast("exp.Alias", exp.alias_(self._expression, alias))
247+
244248
def __repr__(self) -> str:
245249
if self.table:
246250
return f"Column<{self.table}.{self.name}>"
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Expression wrapper classes for proper type annotations."""
2+
3+
from typing import cast
4+
5+
from sqlglot import exp
6+
7+
__all__ = ("AggregateExpression", "ConversionExpression", "FunctionExpression", "MathExpression", "StringExpression")
8+
9+
10+
class ExpressionWrapper:
11+
"""Base wrapper for SQLGlot expressions."""
12+
13+
def __init__(self, expression: exp.Expression) -> None:
14+
self._expression = expression
15+
16+
def as_(self, alias: str) -> exp.Alias:
17+
"""Create an aliased expression."""
18+
return cast("exp.Alias", exp.alias_(self._expression, alias))
19+
20+
@property
21+
def expression(self) -> exp.Expression:
22+
"""Get the underlying SQLGlot expression."""
23+
return self._expression
24+
25+
def __str__(self) -> str:
26+
return str(self._expression)
27+
28+
29+
class AggregateExpression(ExpressionWrapper):
30+
"""Aggregate functions like COUNT, SUM, AVG."""
31+
32+
33+
class FunctionExpression(ExpressionWrapper):
34+
"""General SQL functions."""
35+
36+
37+
class MathExpression(ExpressionWrapper):
38+
"""Mathematical functions like ROUND."""
39+
40+
41+
class StringExpression(ExpressionWrapper):
42+
"""String functions like UPPER, LOWER, LENGTH."""
43+
44+
45+
class ConversionExpression(ExpressionWrapper):
46+
"""Conversion functions like CAST, COALESCE."""

sqlspec/builder/_insert.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,7 @@ def do_update(self, **kwargs: Any) -> "Insert":
412412
# Create ON CONFLICT with proper structure
413413
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
414414
on_conflict = exp.OnConflict(
415-
conflict_keys=conflict_keys,
416-
action=exp.var("DO UPDATE"),
417-
expressions=set_expressions if set_expressions else None,
415+
conflict_keys=conflict_keys, action=exp.var("DO UPDATE"), expressions=set_expressions or None
418416
)
419417

420418
insert_expr.set("conflict", on_conflict)

sqlspec/builder/_update.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,26 @@ class Update(
4444
update_query = (
4545
Update()
4646
.table("users")
47-
.set(name="John Doe")
48-
.set(email="[email protected]")
47+
.set_(name="John Doe")
48+
.set_(email="[email protected]")
4949
.where("id = 1")
5050
)
5151
5252
update_query = (
53-
Update("users").set(name="John Doe").where("id = 1")
53+
Update("users").set_(name="John Doe").where("id = 1")
5454
)
5555
5656
update_query = (
5757
Update()
5858
.table("users")
59-
.set(status="active")
59+
.set_(status="active")
6060
.where_eq("id", 123)
6161
)
6262
6363
update_query = (
6464
Update()
6565
.table("users", "u")
66-
.set(name="Updated Name")
66+
.set_(name="Updated Name")
6767
.from_("profiles", "p")
6868
.where("u.id = p.user_id AND p.is_verified = true")
6969
)

sqlspec/builder/mixins/_order_limit_operations.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from sqlspec.exceptions import SQLBuilderError
1111

1212
if TYPE_CHECKING:
13+
from sqlspec.builder._column import Column
14+
from sqlspec.builder._expression_wrappers import ExpressionWrapper
15+
from sqlspec.builder.mixins._select_operations import Case
1316
from sqlspec.protocols import SQLBuilderProtocol
1417

1518
__all__ = ("LimitOffsetClauseMixin", "OrderByClauseMixin", "ReturningClauseMixin")
@@ -24,7 +27,7 @@ class OrderByClauseMixin:
2427
# Type annotation for PyRight - this will be provided by the base class
2528
_expression: Optional[exp.Expression]
2629

27-
def order_by(self, *items: Union[str, exp.Ordered], desc: bool = False) -> Self:
30+
def order_by(self, *items: Union[str, exp.Ordered, "Column"], desc: bool = False) -> Self:
2831
"""Add ORDER BY clause.
2932
3033
Args:
@@ -49,7 +52,13 @@ def order_by(self, *items: Union[str, exp.Ordered], desc: bool = False) -> Self:
4952
if desc:
5053
order_item = order_item.desc()
5154
else:
52-
order_item = item
55+
# Extract expression from Column objects or use as-is for sqlglot expressions
56+
from sqlspec._sql import SQLFactory
57+
58+
extracted_item = SQLFactory._extract_expression(item)
59+
order_item = extracted_item
60+
if desc and not isinstance(item, exp.Ordered):
61+
order_item = order_item.desc()
5362
current_expr = current_expr.order_by(order_item, copy=False)
5463
builder._expression = current_expr
5564
return cast("Self", builder)
@@ -111,7 +120,7 @@ class ReturningClauseMixin:
111120
# Type annotation for PyRight - this will be provided by the base class
112121
_expression: Optional[exp.Expression]
113122

114-
def returning(self, *columns: Union[str, exp.Expression]) -> Self:
123+
def returning(self, *columns: Union[str, exp.Expression, "Column", "ExpressionWrapper", "Case"]) -> Self:
115124
"""Add RETURNING clause to the statement.
116125
117126
Args:
@@ -130,6 +139,9 @@ def returning(self, *columns: Union[str, exp.Expression]) -> Self:
130139
if not isinstance(self._expression, valid_types):
131140
msg = "RETURNING is only supported for INSERT, UPDATE, and DELETE statements."
132141
raise SQLBuilderError(msg)
133-
returning_exprs = [exp.column(c) if isinstance(c, str) else c for c in columns]
142+
# Extract expressions from various wrapper types
143+
from sqlspec._sql import SQLFactory
144+
145+
returning_exprs = [SQLFactory._extract_expression(c) for c in columns]
134146
self._expression.set("returning", exp.Returning(expressions=returning_exprs))
135147
return self

sqlspec/builder/mixins/_select_operations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def when(self, condition: Union[str, exp.Expression], value: Union[str, exp.Expr
858858
from sqlspec._sql import SQLFactory
859859

860860
cond_expr = exp.maybe_parse(condition) or exp.column(condition) if isinstance(condition, str) else condition
861-
val_expr = SQLFactory._to_literal(value)
861+
val_expr = SQLFactory._to_expression(value)
862862

863863
# SQLGlot uses exp.If for CASE WHEN clauses, not exp.When
864864
when_clause = exp.If(this=cond_expr, true=val_expr)
@@ -876,7 +876,7 @@ def else_(self, value: Union[str, exp.Expression, Any]) -> Self:
876876
"""
877877
from sqlspec._sql import SQLFactory
878878

879-
self._default = SQLFactory._to_literal(value)
879+
self._default = SQLFactory._to_expression(value)
880880
return self
881881

882882
def end(self) -> Self:

sqlspec/builder/mixins/_update_operations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def set(self, *args: Any, **kwargs: Any) -> Self:
111111
"""Set columns and values for the UPDATE statement.
112112
113113
Supports:
114-
- set(column, value)
115-
- set(mapping)
116-
- set(**kwargs)
117-
- set(mapping, **kwargs)
114+
- set_(column, value)
115+
- set_(mapping)
116+
- set_(**kwargs)
117+
- set_(mapping, **kwargs)
118118
119119
Args:
120120
*args: Either (column, value) or a mapping.

0 commit comments

Comments
 (0)