Skip to content

Commit 9f6a8ce

Browse files
authored
fix: Imrpove builder type hints and enhance SQL operations (#51)
Improve type hints for `as_` and `select` methods, and ensure correct handling of `upsert` and `merge` operations with SQL objects. This enhances type safety and functionality in SQL operations.
1 parent a760bc4 commit 9f6a8ce

File tree

12 files changed

+1407
-84
lines changed

12 files changed

+1407
-84
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
1313
name = "sqlspec"
1414
readme = "README.md"
1515
requires-python = ">=3.9, <4.0"
16-
version = "0.16.1"
16+
version = "0.16.2"
1717

1818
[project.urls]
1919
Discord = "https://discord.gg/litestar"

sqlspec/_sql.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ def __call__(self, statement: str, dialect: DialectType = None) -> "Any":
178178
# ===================
179179
# Statement Builders
180180
# ===================
181-
def select(self, *columns_or_sql: Union[str, exp.Expression, Column], dialect: DialectType = None) -> "Select":
181+
def select(
182+
self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL"], dialect: DialectType = None
183+
) -> "Select":
182184
builder_dialect = dialect or self.dialect
183185
if len(columns_or_sql) == 1 and isinstance(columns_or_sql[0], str):
184186
sql_candidate = columns_or_sql[0].strip()
@@ -1531,7 +1533,7 @@ def order_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuild
15311533
self._order_by_cols.append(exp.Ordered(this=col, desc=False))
15321534
return self
15331535

1534-
def as_(self, alias: str) -> exp.Expression:
1536+
def as_(self, alias: str) -> exp.Alias:
15351537
"""Complete the window function with an alias.
15361538
15371539
Args:
@@ -1755,11 +1757,11 @@ def on(self, condition: Union[str, exp.Expression]) -> exp.Expression:
17551757
if isinstance(self._table, str):
17561758
table_expr = exp.to_table(self._table)
17571759
if self._alias:
1758-
table_expr = cast("exp.Expression", exp.alias_(table_expr, self._alias))
1760+
table_expr = exp.alias_(table_expr, self._alias)
17591761
else:
17601762
table_expr = self._table
17611763
if self._alias:
1762-
table_expr = cast("exp.Expression", exp.alias_(table_expr, self._alias))
1764+
table_expr = exp.alias_(table_expr, self._alias)
17631765

17641766
# Create the appropriate join type using same pattern as existing JoinClauseMixin
17651767
if self._join_type == "INNER JOIN":

sqlspec/builder/_insert.py

Lines changed: 177 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,29 @@ def values(self, *values: Any, **kwargs: Any) -> "Self":
142142
for i, value in enumerate(values):
143143
if isinstance(value, exp.Expression):
144144
value_placeholders.append(value)
145+
elif hasattr(value, "expression") and hasattr(value, "sql"):
146+
# Handle SQL objects (from sql.raw with parameters)
147+
expression = getattr(value, "expression", None)
148+
if expression is not None and isinstance(expression, exp.Expression):
149+
# Merge parameters from SQL object into builder
150+
if hasattr(value, "parameters"):
151+
sql_parameters = getattr(value, "parameters", {})
152+
for param_name, param_value in sql_parameters.items():
153+
self.add_parameter(param_value, name=param_name)
154+
value_placeholders.append(expression)
155+
else:
156+
# If expression is None, fall back to parsing the raw SQL
157+
sql_text = getattr(value, "sql", "")
158+
# Merge parameters even when parsing raw SQL
159+
if hasattr(value, "parameters"):
160+
sql_parameters = getattr(value, "parameters", {})
161+
for param_name, param_value in sql_parameters.items():
162+
self.add_parameter(param_value, name=param_name)
163+
# Check if sql_text is callable (like Expression.sql method)
164+
if callable(sql_text):
165+
sql_text = str(value)
166+
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
167+
value_placeholders.append(value_expr)
145168
else:
146169
if self._columns and i < len(self._columns):
147170
column_str = str(self._columns[i])
@@ -228,29 +251,171 @@ def values_from_dicts(self, data: "Sequence[Mapping[str, Any]]") -> "Self":
228251

229252
return self
230253

231-
def on_conflict_do_nothing(self) -> "Self":
232-
"""Adds an ON CONFLICT DO NOTHING clause (PostgreSQL syntax).
254+
def on_conflict(self, *columns: str) -> "ConflictBuilder":
255+
"""Adds an ON CONFLICT clause with specified columns.
256+
257+
Args:
258+
*columns: Column names that define the conflict. If no columns provided,
259+
creates an ON CONFLICT without specific columns (catches all conflicts).
260+
261+
Returns:
262+
A ConflictBuilder instance for chaining conflict resolution methods.
263+
264+
Example:
265+
```python
266+
# ON CONFLICT (id) DO NOTHING
267+
sql.insert("users").values(id=1, name="John").on_conflict(
268+
"id"
269+
).do_nothing()
270+
271+
# ON CONFLICT (email, username) DO UPDATE SET updated_at = NOW()
272+
sql.insert("users").values(...).on_conflict(
273+
"email", "username"
274+
).do_update(updated_at=sql.raw("NOW()"))
275+
276+
# ON CONFLICT DO NOTHING (catches all conflicts)
277+
sql.insert("users").values(...).on_conflict().do_nothing()
278+
```
279+
"""
280+
return ConflictBuilder(self, columns)
281+
282+
def on_conflict_do_nothing(self, *columns: str) -> "Insert":
283+
"""Adds an ON CONFLICT DO NOTHING clause (convenience method).
233284
234-
This is used to ignore rows that would cause a conflict.
285+
Args:
286+
*columns: Column names that define the conflict. If no columns provided,
287+
creates an ON CONFLICT without specific columns.
235288
236289
Returns:
237290
The current builder instance for method chaining.
238291
239292
Note:
240-
This is PostgreSQL-specific syntax. Different databases have different syntax.
241-
For a more general solution, you might need dialect-specific handling.
293+
This is a convenience method. For more control, use on_conflict().do_nothing().
242294
"""
243-
insert_expr = self._get_insert_expression()
244-
insert_expr.set("on", exp.OnConflict(this=None, expressions=[]))
245-
return self
295+
return self.on_conflict(*columns).do_nothing()
246296

247-
def on_duplicate_key_update(self, **_: Any) -> "Self":
248-
"""Adds an ON DUPLICATE KEY UPDATE clause (MySQL syntax).
297+
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
298+
"""Adds conflict resolution using the ON CONFLICT syntax (cross-database compatible).
249299
250300
Args:
251-
**_: Column-value pairs to update on duplicate key.
301+
**kwargs: Column-value pairs to update on conflict.
252302
253303
Returns:
254304
The current builder instance for method chaining.
305+
306+
Note:
307+
This method uses PostgreSQL-style ON CONFLICT syntax but SQLGlot will
308+
transpile it to the appropriate syntax for each database (MySQL's
309+
ON DUPLICATE KEY UPDATE, etc.).
255310
"""
256-
return self
311+
if not kwargs:
312+
return self
313+
return self.on_conflict().do_update(**kwargs)
314+
315+
316+
class ConflictBuilder:
317+
"""Builder for ON CONFLICT clauses in INSERT statements.
318+
319+
This builder provides a fluent interface for constructing conflict resolution
320+
clauses using PostgreSQL-style syntax, which SQLGlot can transpile to other dialects.
321+
"""
322+
323+
__slots__ = ("_columns", "_insert_builder")
324+
325+
def __init__(self, insert_builder: "Insert", columns: tuple[str, ...]) -> None:
326+
"""Initialize ConflictBuilder.
327+
328+
Args:
329+
insert_builder: The parent Insert builder
330+
columns: Column names that define the conflict
331+
"""
332+
self._insert_builder = insert_builder
333+
self._columns = columns
334+
335+
def do_nothing(self) -> "Insert":
336+
"""Add DO NOTHING conflict resolution.
337+
338+
Returns:
339+
The parent Insert builder for method chaining.
340+
341+
Example:
342+
```python
343+
sql.insert("users").values(id=1, name="John").on_conflict(
344+
"id"
345+
).do_nothing()
346+
```
347+
"""
348+
insert_expr = self._insert_builder._get_insert_expression()
349+
350+
# Create ON CONFLICT with proper structure
351+
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
352+
on_conflict = exp.OnConflict(conflict_keys=conflict_keys, action=exp.var("DO NOTHING"))
353+
354+
insert_expr.set("conflict", on_conflict)
355+
return self._insert_builder
356+
357+
def do_update(self, **kwargs: Any) -> "Insert":
358+
"""Add DO UPDATE conflict resolution with SET clauses.
359+
360+
Args:
361+
**kwargs: Column-value pairs to update on conflict.
362+
363+
Returns:
364+
The parent Insert builder for method chaining.
365+
366+
Example:
367+
```python
368+
sql.insert("users").values(id=1, name="John").on_conflict(
369+
"id"
370+
).do_update(
371+
name="Updated Name", updated_at=sql.raw("NOW()")
372+
)
373+
```
374+
"""
375+
insert_expr = self._insert_builder._get_insert_expression()
376+
377+
# Create SET expressions for the UPDATE
378+
set_expressions = []
379+
for col, val in kwargs.items():
380+
if hasattr(val, "expression") and hasattr(val, "sql"):
381+
# Handle SQL objects (from sql.raw with parameters)
382+
expression = getattr(val, "expression", None)
383+
if expression is not None and isinstance(expression, exp.Expression):
384+
# Merge parameters from SQL object into builder
385+
if hasattr(val, "parameters"):
386+
sql_parameters = getattr(val, "parameters", {})
387+
for param_name, param_value in sql_parameters.items():
388+
self._insert_builder.add_parameter(param_value, name=param_name)
389+
value_expr = expression
390+
else:
391+
# If expression is None, fall back to parsing the raw SQL
392+
sql_text = getattr(val, "sql", "")
393+
# Merge parameters even when parsing raw SQL
394+
if hasattr(val, "parameters"):
395+
sql_parameters = getattr(val, "parameters", {})
396+
for param_name, param_value in sql_parameters.items():
397+
self._insert_builder.add_parameter(param_value, name=param_name)
398+
# Check if sql_text is callable (like Expression.sql method)
399+
if callable(sql_text):
400+
sql_text = str(val)
401+
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
402+
elif isinstance(val, exp.Expression):
403+
value_expr = val
404+
else:
405+
# Create parameter for regular values
406+
param_name = self._insert_builder._generate_unique_parameter_name(col)
407+
_, param_name = self._insert_builder.add_parameter(val, name=param_name)
408+
value_expr = exp.Placeholder(this=param_name)
409+
410+
set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
411+
412+
# Create ON CONFLICT with proper structure
413+
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
414+
on_conflict = exp.OnConflict(
415+
conflict_keys=conflict_keys,
416+
action=exp.var("DO UPDATE"),
417+
expressions=set_expressions if set_expressions else None,
418+
)
419+
420+
insert_expr.set("conflict", on_conflict)
421+
return self._insert_builder

sqlspec/builder/_parsing_utils.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from sqlspec.utils.type_guards import has_expression_attr, has_parameter_builder
1313

1414

15-
def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> exp.Expression:
15+
def parse_column_expression(
16+
column_input: Union[str, exp.Expression, Any], builder: Optional[Any] = None
17+
) -> exp.Expression:
1618
"""Parse a column input that might be a complex expression.
1719
1820
Handles cases like:
@@ -22,16 +24,38 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex
2224
- Function calls: "MAX(price)" -> Max(this=Column(price))
2325
- Complex expressions: "CASE WHEN ... END" -> Case(...)
2426
- Custom Column objects from our builder
27+
- SQL objects with raw SQL expressions
2528
2629
Args:
27-
column_input: String, SQLGlot expression, or Column object
30+
column_input: String, SQLGlot expression, SQL object, or Column object
31+
builder: Optional builder instance for parameter merging
2832
2933
Returns:
3034
exp.Expression: Parsed SQLGlot expression
3135
"""
3236
if isinstance(column_input, exp.Expression):
3337
return column_input
3438

39+
# Handle SQL objects (from sql.raw with parameters)
40+
if hasattr(column_input, "expression") and hasattr(column_input, "sql"):
41+
# This is likely a SQL object
42+
expression = getattr(column_input, "expression", None)
43+
if expression is not None and isinstance(expression, exp.Expression):
44+
# Merge parameters from SQL object into builder if available
45+
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
46+
sql_parameters = getattr(column_input, "parameters", {})
47+
for param_name, param_value in sql_parameters.items():
48+
builder.add_parameter(param_value, name=param_name)
49+
return cast("exp.Expression", expression)
50+
# If expression is None, fall back to parsing the raw SQL
51+
sql_text = getattr(column_input, "sql", "")
52+
# Merge parameters even when parsing raw SQL
53+
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
54+
sql_parameters = getattr(column_input, "parameters", {})
55+
for param_name, param_value in sql_parameters.items():
56+
builder.add_parameter(param_value, name=param_name)
57+
return exp.maybe_parse(sql_text) or exp.column(str(sql_text))
58+
3559
if has_expression_attr(column_input):
3660
try:
3761
attr_value = column_input._expression

sqlspec/builder/mixins/_join_operations.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqlspec.utils.type_guards import has_query_builder_parameters
1010

1111
if TYPE_CHECKING:
12+
from sqlspec.core.statement import SQL
1213
from sqlspec.protocols import SQLBuilderProtocol
1314

1415
__all__ = ("JoinClauseMixin",)
@@ -26,7 +27,7 @@ class JoinClauseMixin:
2627
def join(
2728
self,
2829
table: Union[str, exp.Expression, Any],
29-
on: Optional[Union[str, exp.Expression]] = None,
30+
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
3031
alias: Optional[str] = None,
3132
join_type: str = "INNER",
3233
) -> Self:
@@ -56,7 +57,33 @@ def join(
5657
table_expr = table
5758
on_expr: Optional[exp.Expression] = None
5859
if on is not None:
59-
on_expr = exp.condition(on) if isinstance(on, str) else on
60+
if isinstance(on, str):
61+
on_expr = exp.condition(on)
62+
elif hasattr(on, "expression") and hasattr(on, "sql"):
63+
# Handle SQL objects (from sql.raw with parameters)
64+
expression = getattr(on, "expression", None)
65+
if expression is not None and isinstance(expression, exp.Expression):
66+
# Merge parameters from SQL object into builder
67+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
68+
sql_parameters = getattr(on, "parameters", {})
69+
for param_name, param_value in sql_parameters.items():
70+
builder.add_parameter(param_value, name=param_name)
71+
on_expr = expression
72+
else:
73+
# If expression is None, fall back to parsing the raw SQL
74+
sql_text = getattr(on, "sql", "")
75+
# Merge parameters even when parsing raw SQL
76+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
77+
sql_parameters = getattr(on, "parameters", {})
78+
for param_name, param_value in sql_parameters.items():
79+
builder.add_parameter(param_value, name=param_name)
80+
on_expr = exp.maybe_parse(sql_text) or exp.condition(str(sql_text))
81+
# For other types (should be exp.Expression)
82+
elif isinstance(on, exp.Expression):
83+
on_expr = on
84+
else:
85+
# Last resort - convert to string and parse
86+
on_expr = exp.condition(str(on))
6087
join_type_upper = join_type.upper()
6188
if join_type_upper == "INNER":
6289
join_expr = exp.Join(this=table_expr, on=on_expr)
@@ -73,22 +100,22 @@ def join(
73100
return cast("Self", builder)
74101

75102
def inner_join(
76-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
103+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
77104
) -> Self:
78105
return self.join(table, on, alias, "INNER")
79106

80107
def left_join(
81-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
108+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
82109
) -> Self:
83110
return self.join(table, on, alias, "LEFT")
84111

85112
def right_join(
86-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
113+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
87114
) -> Self:
88115
return self.join(table, on, alias, "RIGHT")
89116

90117
def full_join(
91-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
118+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
92119
) -> Self:
93120
return self.join(table, on, alias, "FULL")
94121

0 commit comments

Comments
 (0)