Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
name = "sqlspec"
readme = "README.md"
requires-python = ">=3.9, <4.0"
version = "0.16.1"
version = "0.16.2"

[project.urls]
Discord = "https://discord.gg/litestar"
Expand Down
10 changes: 6 additions & 4 deletions sqlspec/_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def __call__(self, statement: str, dialect: DialectType = None) -> "Any":
# ===================
# Statement Builders
# ===================
def select(self, *columns_or_sql: Union[str, exp.Expression, Column], dialect: DialectType = None) -> "Select":
def select(
self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL"], dialect: DialectType = None
) -> "Select":
builder_dialect = dialect or self.dialect
if len(columns_or_sql) == 1 and isinstance(columns_or_sql[0], str):
sql_candidate = columns_or_sql[0].strip()
Expand Down Expand Up @@ -1531,7 +1533,7 @@ def order_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuild
self._order_by_cols.append(exp.Ordered(this=col, desc=False))
return self

def as_(self, alias: str) -> exp.Expression:
def as_(self, alias: str) -> exp.Alias:
"""Complete the window function with an alias.

Args:
Expand Down Expand Up @@ -1755,11 +1757,11 @@ def on(self, condition: Union[str, exp.Expression]) -> exp.Expression:
if isinstance(self._table, str):
table_expr = exp.to_table(self._table)
if self._alias:
table_expr = cast("exp.Expression", exp.alias_(table_expr, self._alias))
table_expr = exp.alias_(table_expr, self._alias)
else:
table_expr = self._table
if self._alias:
table_expr = cast("exp.Expression", exp.alias_(table_expr, self._alias))
table_expr = exp.alias_(table_expr, self._alias)

# Create the appropriate join type using same pattern as existing JoinClauseMixin
if self._join_type == "INNER JOIN":
Expand Down
189 changes: 177 additions & 12 deletions sqlspec/builder/_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ def values(self, *values: Any, **kwargs: Any) -> "Self":
for i, value in enumerate(values):
if isinstance(value, exp.Expression):
value_placeholders.append(value)
elif hasattr(value, "expression") and hasattr(value, "sql"):
# Handle SQL objects (from sql.raw with parameters)
expression = getattr(value, "expression", None)
if expression is not None and isinstance(expression, exp.Expression):
# Merge parameters from SQL object into builder
if hasattr(value, "parameters"):
sql_parameters = getattr(value, "parameters", {})
for param_name, param_value in sql_parameters.items():
self.add_parameter(param_value, name=param_name)
value_placeholders.append(expression)
else:
# If expression is None, fall back to parsing the raw SQL
sql_text = getattr(value, "sql", "")
# Merge parameters even when parsing raw SQL
if hasattr(value, "parameters"):
sql_parameters = getattr(value, "parameters", {})
for param_name, param_value in sql_parameters.items():
self.add_parameter(param_value, name=param_name)
# Check if sql_text is callable (like Expression.sql method)
if callable(sql_text):
sql_text = str(value)
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
value_placeholders.append(value_expr)
else:
if self._columns and i < len(self._columns):
column_str = str(self._columns[i])
Expand Down Expand Up @@ -228,29 +251,171 @@ def values_from_dicts(self, data: "Sequence[Mapping[str, Any]]") -> "Self":

return self

def on_conflict_do_nothing(self) -> "Self":
"""Adds an ON CONFLICT DO NOTHING clause (PostgreSQL syntax).
def on_conflict(self, *columns: str) -> "ConflictBuilder":
"""Adds an ON CONFLICT clause with specified columns.

Args:
*columns: Column names that define the conflict. If no columns provided,
creates an ON CONFLICT without specific columns (catches all conflicts).

Returns:
A ConflictBuilder instance for chaining conflict resolution methods.

Example:
```python
# ON CONFLICT (id) DO NOTHING
sql.insert("users").values(id=1, name="John").on_conflict(
"id"
).do_nothing()

# ON CONFLICT (email, username) DO UPDATE SET updated_at = NOW()
sql.insert("users").values(...).on_conflict(
"email", "username"
).do_update(updated_at=sql.raw("NOW()"))

# ON CONFLICT DO NOTHING (catches all conflicts)
sql.insert("users").values(...).on_conflict().do_nothing()
```
"""
return ConflictBuilder(self, columns)

def on_conflict_do_nothing(self, *columns: str) -> "Insert":
"""Adds an ON CONFLICT DO NOTHING clause (convenience method).

This is used to ignore rows that would cause a conflict.
Args:
*columns: Column names that define the conflict. If no columns provided,
creates an ON CONFLICT without specific columns.

Returns:
The current builder instance for method chaining.

Note:
This is PostgreSQL-specific syntax. Different databases have different syntax.
For a more general solution, you might need dialect-specific handling.
This is a convenience method. For more control, use on_conflict().do_nothing().
"""
insert_expr = self._get_insert_expression()
insert_expr.set("on", exp.OnConflict(this=None, expressions=[]))
return self
return self.on_conflict(*columns).do_nothing()

def on_duplicate_key_update(self, **_: Any) -> "Self":
"""Adds an ON DUPLICATE KEY UPDATE clause (MySQL syntax).
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
"""Adds conflict resolution using the ON CONFLICT syntax (cross-database compatible).

Args:
**_: Column-value pairs to update on duplicate key.
**kwargs: Column-value pairs to update on conflict.

Returns:
The current builder instance for method chaining.

Note:
This method uses PostgreSQL-style ON CONFLICT syntax but SQLGlot will
transpile it to the appropriate syntax for each database (MySQL's
ON DUPLICATE KEY UPDATE, etc.).
"""
return self
if not kwargs:
return self
return self.on_conflict().do_update(**kwargs)


class ConflictBuilder:
"""Builder for ON CONFLICT clauses in INSERT statements.

This builder provides a fluent interface for constructing conflict resolution
clauses using PostgreSQL-style syntax, which SQLGlot can transpile to other dialects.
"""

__slots__ = ("_columns", "_insert_builder")

def __init__(self, insert_builder: "Insert", columns: tuple[str, ...]) -> None:
"""Initialize ConflictBuilder.

Args:
insert_builder: The parent Insert builder
columns: Column names that define the conflict
"""
self._insert_builder = insert_builder
self._columns = columns

def do_nothing(self) -> "Insert":
"""Add DO NOTHING conflict resolution.

Returns:
The parent Insert builder for method chaining.

Example:
```python
sql.insert("users").values(id=1, name="John").on_conflict(
"id"
).do_nothing()
```
"""
insert_expr = self._insert_builder._get_insert_expression()

# Create ON CONFLICT with proper structure
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
on_conflict = exp.OnConflict(conflict_keys=conflict_keys, action=exp.var("DO NOTHING"))

insert_expr.set("conflict", on_conflict)
return self._insert_builder

def do_update(self, **kwargs: Any) -> "Insert":
"""Add DO UPDATE conflict resolution with SET clauses.

Args:
**kwargs: Column-value pairs to update on conflict.

Returns:
The parent Insert builder for method chaining.

Example:
```python
sql.insert("users").values(id=1, name="John").on_conflict(
"id"
).do_update(
name="Updated Name", updated_at=sql.raw("NOW()")
)
```
"""
insert_expr = self._insert_builder._get_insert_expression()

# Create SET expressions for the UPDATE
set_expressions = []
for col, val in kwargs.items():
if hasattr(val, "expression") and hasattr(val, "sql"):
# Handle SQL objects (from sql.raw with parameters)
expression = getattr(val, "expression", None)
if expression is not None and isinstance(expression, exp.Expression):
# Merge parameters from SQL object into builder
if hasattr(val, "parameters"):
sql_parameters = getattr(val, "parameters", {})
for param_name, param_value in sql_parameters.items():
self._insert_builder.add_parameter(param_value, name=param_name)
value_expr = expression
else:
# If expression is None, fall back to parsing the raw SQL
sql_text = getattr(val, "sql", "")
# Merge parameters even when parsing raw SQL
if hasattr(val, "parameters"):
sql_parameters = getattr(val, "parameters", {})
for param_name, param_value in sql_parameters.items():
self._insert_builder.add_parameter(param_value, name=param_name)
# Check if sql_text is callable (like Expression.sql method)
if callable(sql_text):
sql_text = str(val)
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
elif isinstance(val, exp.Expression):
value_expr = val
else:
# Create parameter for regular values
param_name = self._insert_builder._generate_unique_parameter_name(col)
_, param_name = self._insert_builder.add_parameter(val, name=param_name)
value_expr = exp.Placeholder(this=param_name)

set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))

# Create ON CONFLICT with proper structure
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
on_conflict = exp.OnConflict(
conflict_keys=conflict_keys,
action=exp.var("DO UPDATE"),
expressions=set_expressions if set_expressions else None,
)

insert_expr.set("conflict", on_conflict)
return self._insert_builder
28 changes: 26 additions & 2 deletions sqlspec/builder/_parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from sqlspec.utils.type_guards import has_expression_attr, has_parameter_builder


def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> exp.Expression:
def parse_column_expression(
column_input: Union[str, exp.Expression, Any], builder: Optional[Any] = None
) -> exp.Expression:
"""Parse a column input that might be a complex expression.

Handles cases like:
Expand All @@ -22,16 +24,38 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex
- Function calls: "MAX(price)" -> Max(this=Column(price))
- Complex expressions: "CASE WHEN ... END" -> Case(...)
- Custom Column objects from our builder
- SQL objects with raw SQL expressions

Args:
column_input: String, SQLGlot expression, or Column object
column_input: String, SQLGlot expression, SQL object, or Column object
builder: Optional builder instance for parameter merging

Returns:
exp.Expression: Parsed SQLGlot expression
"""
if isinstance(column_input, exp.Expression):
return column_input

# Handle SQL objects (from sql.raw with parameters)
if hasattr(column_input, "expression") and hasattr(column_input, "sql"):
# This is likely a SQL object
expression = getattr(column_input, "expression", None)
if expression is not None and isinstance(expression, exp.Expression):
# Merge parameters from SQL object into builder if available
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
sql_parameters = getattr(column_input, "parameters", {})
for param_name, param_value in sql_parameters.items():
builder.add_parameter(param_value, name=param_name)
return cast("exp.Expression", expression)
# If expression is None, fall back to parsing the raw SQL
sql_text = getattr(column_input, "sql", "")
# Merge parameters even when parsing raw SQL
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
sql_parameters = getattr(column_input, "parameters", {})
for param_name, param_value in sql_parameters.items():
builder.add_parameter(param_value, name=param_name)
return exp.maybe_parse(sql_text) or exp.column(str(sql_text))

if has_expression_attr(column_input):
try:
attr_value = column_input._expression
Expand Down
39 changes: 33 additions & 6 deletions sqlspec/builder/mixins/_join_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlspec.utils.type_guards import has_query_builder_parameters

if TYPE_CHECKING:
from sqlspec.core.statement import SQL
from sqlspec.protocols import SQLBuilderProtocol

__all__ = ("JoinClauseMixin",)
Expand All @@ -26,7 +27,7 @@ class JoinClauseMixin:
def join(
self,
table: Union[str, exp.Expression, Any],
on: Optional[Union[str, exp.Expression]] = None,
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
alias: Optional[str] = None,
join_type: str = "INNER",
) -> Self:
Expand Down Expand Up @@ -56,7 +57,33 @@ def join(
table_expr = table
on_expr: Optional[exp.Expression] = None
if on is not None:
on_expr = exp.condition(on) if isinstance(on, str) else on
if isinstance(on, str):
on_expr = exp.condition(on)
elif hasattr(on, "expression") and hasattr(on, "sql"):
# Handle SQL objects (from sql.raw with parameters)
expression = getattr(on, "expression", None)
if expression is not None and isinstance(expression, exp.Expression):
# Merge parameters from SQL object into builder
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
sql_parameters = getattr(on, "parameters", {})
for param_name, param_value in sql_parameters.items():
builder.add_parameter(param_value, name=param_name)
on_expr = expression
else:
# If expression is None, fall back to parsing the raw SQL
sql_text = getattr(on, "sql", "")
# Merge parameters even when parsing raw SQL
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
sql_parameters = getattr(on, "parameters", {})
for param_name, param_value in sql_parameters.items():
builder.add_parameter(param_value, name=param_name)
on_expr = exp.maybe_parse(sql_text) or exp.condition(str(sql_text))
# For other types (should be exp.Expression)
elif isinstance(on, exp.Expression):
on_expr = on
else:
# Last resort - convert to string and parse
on_expr = exp.condition(str(on))
join_type_upper = join_type.upper()
if join_type_upper == "INNER":
join_expr = exp.Join(this=table_expr, on=on_expr)
Expand All @@ -73,22 +100,22 @@ def join(
return cast("Self", builder)

def inner_join(
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
) -> Self:
return self.join(table, on, alias, "INNER")

def left_join(
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
) -> Self:
return self.join(table, on, alias, "LEFT")

def right_join(
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
) -> Self:
return self.join(table, on, alias, "RIGHT")

def full_join(
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
) -> Self:
return self.join(table, on, alias, "FULL")

Expand Down
Loading