diff --git a/pyproject.toml b/pyproject.toml index bcb10a188..44dfc763d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }] name = "sqlspec" readme = "README.md" requires-python = ">=3.9, <4.0" -version = "0.16.1" +version = "0.16.2" [project.urls] Discord = "https://discord.gg/litestar" diff --git a/sqlspec/_sql.py b/sqlspec/_sql.py index 9b1c63ee2..a598b6e2e 100644 --- a/sqlspec/_sql.py +++ b/sqlspec/_sql.py @@ -178,7 +178,9 @@ def __call__(self, statement: str, dialect: DialectType = None) -> "Any": # =================== # Statement Builders # =================== - def select(self, *columns_or_sql: Union[str, exp.Expression, Column], dialect: DialectType = None) -> "Select": + def select( + self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL"], dialect: DialectType = None + ) -> "Select": builder_dialect = dialect or self.dialect if len(columns_or_sql) == 1 and isinstance(columns_or_sql[0], str): sql_candidate = columns_or_sql[0].strip() @@ -1531,7 +1533,7 @@ def order_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuild self._order_by_cols.append(exp.Ordered(this=col, desc=False)) return self - def as_(self, alias: str) -> exp.Expression: + def as_(self, alias: str) -> exp.Alias: """Complete the window function with an alias. Args: @@ -1755,11 +1757,11 @@ def on(self, condition: Union[str, exp.Expression]) -> exp.Expression: if isinstance(self._table, str): table_expr = exp.to_table(self._table) if self._alias: - table_expr = cast("exp.Expression", exp.alias_(table_expr, self._alias)) + table_expr = exp.alias_(table_expr, self._alias) else: table_expr = self._table if self._alias: - table_expr = cast("exp.Expression", exp.alias_(table_expr, self._alias)) + table_expr = exp.alias_(table_expr, self._alias) # Create the appropriate join type using same pattern as existing JoinClauseMixin if self._join_type == "INNER JOIN": diff --git a/sqlspec/builder/_insert.py b/sqlspec/builder/_insert.py index 286360f48..3d88940ba 100644 --- a/sqlspec/builder/_insert.py +++ b/sqlspec/builder/_insert.py @@ -142,6 +142,29 @@ def values(self, *values: Any, **kwargs: Any) -> "Self": for i, value in enumerate(values): if isinstance(value, exp.Expression): value_placeholders.append(value) + elif hasattr(value, "expression") and hasattr(value, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(value, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(value, "parameters"): + sql_parameters = getattr(value, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + value_placeholders.append(expression) + else: + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(value, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(value, "parameters"): + sql_parameters = getattr(value, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + # Check if sql_text is callable (like Expression.sql method) + if callable(sql_text): + sql_text = str(value) + value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text)) + value_placeholders.append(value_expr) else: if self._columns and i < len(self._columns): column_str = str(self._columns[i]) @@ -228,29 +251,171 @@ def values_from_dicts(self, data: "Sequence[Mapping[str, Any]]") -> "Self": return self - def on_conflict_do_nothing(self) -> "Self": - """Adds an ON CONFLICT DO NOTHING clause (PostgreSQL syntax). + def on_conflict(self, *columns: str) -> "ConflictBuilder": + """Adds an ON CONFLICT clause with specified columns. + + Args: + *columns: Column names that define the conflict. If no columns provided, + creates an ON CONFLICT without specific columns (catches all conflicts). + + Returns: + A ConflictBuilder instance for chaining conflict resolution methods. + + Example: + ```python + # ON CONFLICT (id) DO NOTHING + sql.insert("users").values(id=1, name="John").on_conflict( + "id" + ).do_nothing() + + # ON CONFLICT (email, username) DO UPDATE SET updated_at = NOW() + sql.insert("users").values(...).on_conflict( + "email", "username" + ).do_update(updated_at=sql.raw("NOW()")) + + # ON CONFLICT DO NOTHING (catches all conflicts) + sql.insert("users").values(...).on_conflict().do_nothing() + ``` + """ + return ConflictBuilder(self, columns) + + def on_conflict_do_nothing(self, *columns: str) -> "Insert": + """Adds an ON CONFLICT DO NOTHING clause (convenience method). - This is used to ignore rows that would cause a conflict. + Args: + *columns: Column names that define the conflict. If no columns provided, + creates an ON CONFLICT without specific columns. Returns: The current builder instance for method chaining. Note: - This is PostgreSQL-specific syntax. Different databases have different syntax. - For a more general solution, you might need dialect-specific handling. + This is a convenience method. For more control, use on_conflict().do_nothing(). """ - insert_expr = self._get_insert_expression() - insert_expr.set("on", exp.OnConflict(this=None, expressions=[])) - return self + return self.on_conflict(*columns).do_nothing() - def on_duplicate_key_update(self, **_: Any) -> "Self": - """Adds an ON DUPLICATE KEY UPDATE clause (MySQL syntax). + def on_duplicate_key_update(self, **kwargs: Any) -> "Insert": + """Adds conflict resolution using the ON CONFLICT syntax (cross-database compatible). Args: - **_: Column-value pairs to update on duplicate key. + **kwargs: Column-value pairs to update on conflict. Returns: The current builder instance for method chaining. + + Note: + This method uses PostgreSQL-style ON CONFLICT syntax but SQLGlot will + transpile it to the appropriate syntax for each database (MySQL's + ON DUPLICATE KEY UPDATE, etc.). """ - return self + if not kwargs: + return self + return self.on_conflict().do_update(**kwargs) + + +class ConflictBuilder: + """Builder for ON CONFLICT clauses in INSERT statements. + + This builder provides a fluent interface for constructing conflict resolution + clauses using PostgreSQL-style syntax, which SQLGlot can transpile to other dialects. + """ + + __slots__ = ("_columns", "_insert_builder") + + def __init__(self, insert_builder: "Insert", columns: tuple[str, ...]) -> None: + """Initialize ConflictBuilder. + + Args: + insert_builder: The parent Insert builder + columns: Column names that define the conflict + """ + self._insert_builder = insert_builder + self._columns = columns + + def do_nothing(self) -> "Insert": + """Add DO NOTHING conflict resolution. + + Returns: + The parent Insert builder for method chaining. + + Example: + ```python + sql.insert("users").values(id=1, name="John").on_conflict( + "id" + ).do_nothing() + ``` + """ + insert_expr = self._insert_builder._get_insert_expression() + + # Create ON CONFLICT with proper structure + conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None + on_conflict = exp.OnConflict(conflict_keys=conflict_keys, action=exp.var("DO NOTHING")) + + insert_expr.set("conflict", on_conflict) + return self._insert_builder + + def do_update(self, **kwargs: Any) -> "Insert": + """Add DO UPDATE conflict resolution with SET clauses. + + Args: + **kwargs: Column-value pairs to update on conflict. + + Returns: + The parent Insert builder for method chaining. + + Example: + ```python + sql.insert("users").values(id=1, name="John").on_conflict( + "id" + ).do_update( + name="Updated Name", updated_at=sql.raw("NOW()") + ) + ``` + """ + insert_expr = self._insert_builder._get_insert_expression() + + # Create SET expressions for the UPDATE + set_expressions = [] + for col, val in kwargs.items(): + if hasattr(val, "expression") and hasattr(val, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(val, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self._insert_builder.add_parameter(param_value, name=param_name) + value_expr = expression + else: + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(val, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self._insert_builder.add_parameter(param_value, name=param_name) + # Check if sql_text is callable (like Expression.sql method) + if callable(sql_text): + sql_text = str(val) + value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text)) + elif isinstance(val, exp.Expression): + value_expr = val + else: + # Create parameter for regular values + param_name = self._insert_builder._generate_unique_parameter_name(col) + _, param_name = self._insert_builder.add_parameter(val, name=param_name) + value_expr = exp.Placeholder(this=param_name) + + set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr)) + + # Create ON CONFLICT with proper structure + conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None + on_conflict = exp.OnConflict( + conflict_keys=conflict_keys, + action=exp.var("DO UPDATE"), + expressions=set_expressions if set_expressions else None, + ) + + insert_expr.set("conflict", on_conflict) + return self._insert_builder diff --git a/sqlspec/builder/_parsing_utils.py b/sqlspec/builder/_parsing_utils.py index 6b18cba58..56e6493c3 100644 --- a/sqlspec/builder/_parsing_utils.py +++ b/sqlspec/builder/_parsing_utils.py @@ -12,7 +12,9 @@ from sqlspec.utils.type_guards import has_expression_attr, has_parameter_builder -def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> exp.Expression: +def parse_column_expression( + column_input: Union[str, exp.Expression, Any], builder: Optional[Any] = None +) -> exp.Expression: """Parse a column input that might be a complex expression. Handles cases like: @@ -22,9 +24,11 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex - Function calls: "MAX(price)" -> Max(this=Column(price)) - Complex expressions: "CASE WHEN ... END" -> Case(...) - Custom Column objects from our builder + - SQL objects with raw SQL expressions Args: - column_input: String, SQLGlot expression, or Column object + column_input: String, SQLGlot expression, SQL object, or Column object + builder: Optional builder instance for parameter merging Returns: exp.Expression: Parsed SQLGlot expression @@ -32,6 +36,26 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex if isinstance(column_input, exp.Expression): return column_input + # Handle SQL objects (from sql.raw with parameters) + if hasattr(column_input, "expression") and hasattr(column_input, "sql"): + # This is likely a SQL object + expression = getattr(column_input, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder if available + if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"): + sql_parameters = getattr(column_input, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + builder.add_parameter(param_value, name=param_name) + return cast("exp.Expression", expression) + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(column_input, "sql", "") + # Merge parameters even when parsing raw SQL + if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"): + sql_parameters = getattr(column_input, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + builder.add_parameter(param_value, name=param_name) + return exp.maybe_parse(sql_text) or exp.column(str(sql_text)) + if has_expression_attr(column_input): try: attr_value = column_input._expression diff --git a/sqlspec/builder/mixins/_join_operations.py b/sqlspec/builder/mixins/_join_operations.py index 6b6d292a7..07b382cf3 100644 --- a/sqlspec/builder/mixins/_join_operations.py +++ b/sqlspec/builder/mixins/_join_operations.py @@ -9,6 +9,7 @@ from sqlspec.utils.type_guards import has_query_builder_parameters if TYPE_CHECKING: + from sqlspec.core.statement import SQL from sqlspec.protocols import SQLBuilderProtocol __all__ = ("JoinClauseMixin",) @@ -26,7 +27,7 @@ class JoinClauseMixin: def join( self, table: Union[str, exp.Expression, Any], - on: Optional[Union[str, exp.Expression]] = None, + on: Optional[Union[str, exp.Expression, "SQL"]] = None, alias: Optional[str] = None, join_type: str = "INNER", ) -> Self: @@ -56,7 +57,33 @@ def join( table_expr = table on_expr: Optional[exp.Expression] = None if on is not None: - on_expr = exp.condition(on) if isinstance(on, str) else on + if isinstance(on, str): + on_expr = exp.condition(on) + elif hasattr(on, "expression") and hasattr(on, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(on, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(on, "parameters") and hasattr(builder, "add_parameter"): + sql_parameters = getattr(on, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + builder.add_parameter(param_value, name=param_name) + on_expr = expression + else: + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(on, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(on, "parameters") and hasattr(builder, "add_parameter"): + sql_parameters = getattr(on, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + builder.add_parameter(param_value, name=param_name) + on_expr = exp.maybe_parse(sql_text) or exp.condition(str(sql_text)) + # For other types (should be exp.Expression) + elif isinstance(on, exp.Expression): + on_expr = on + else: + # Last resort - convert to string and parse + on_expr = exp.condition(str(on)) join_type_upper = join_type.upper() if join_type_upper == "INNER": join_expr = exp.Join(this=table_expr, on=on_expr) @@ -73,22 +100,22 @@ def join( return cast("Self", builder) def inner_join( - self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None + self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None ) -> Self: return self.join(table, on, alias, "INNER") def left_join( - self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None + self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None ) -> Self: return self.join(table, on, alias, "LEFT") def right_join( - self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None + self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None ) -> Self: return self.join(table, on, alias, "RIGHT") def full_join( - self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None + self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None ) -> Self: return self.join(table, on, alias, "FULL") diff --git a/sqlspec/builder/mixins/_merge_operations.py b/sqlspec/builder/mixins/_merge_operations.py index c0fa70df6..190637097 100644 --- a/sqlspec/builder/mixins/_merge_operations.py +++ b/sqlspec/builder/mixins/_merge_operations.py @@ -179,14 +179,23 @@ def _add_when_clause(self, when_clause: exp.When) -> None: whens.append("expressions", when_clause) def when_matched_then_update( - self, set_values: dict[str, Any], condition: Optional[Union[str, exp.Expression]] = None + self, + set_values: Optional[dict[str, Any]] = None, + condition: Optional[Union[str, exp.Expression]] = None, + **kwargs: Any, ) -> Self: """Define the UPDATE action for matched rows. + Supports: + - when_matched_then_update({"column": value}) + - when_matched_then_update(column=value, other_column=other_value) + - when_matched_then_update({"column": value}, other_column=other_value) + Args: set_values: A dictionary of column names and their new values to set. The values will be parameterized. condition: An optional additional condition for this specific action. + **kwargs: Column-value pairs to update on match. Raises: SQLBuilderError: If the condition type is unsupported. @@ -194,14 +203,48 @@ def when_matched_then_update( Returns: The current builder instance for method chaining. """ + # Combine set_values dict and kwargs + all_values = dict(set_values or {}, **kwargs) + + if not all_values: + msg = "No update values provided. Use set_values dict or kwargs." + raise SQLBuilderError(msg) + update_expressions: list[exp.EQ] = [] - for col, val in set_values.items(): - column_name = col if isinstance(col, str) else str(col) - if "." in column_name: - column_name = column_name.split(".")[-1] - param_name = self._generate_unique_parameter_name(column_name) - param_name = self.add_parameter(val, name=param_name)[1] - update_expressions.append(exp.EQ(this=exp.column(col), expression=exp.var(param_name))) + for col, val in all_values.items(): + if hasattr(val, "expression") and hasattr(val, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(val, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + value_expr = expression + else: + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(val, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + # Check if sql_text is callable (like Expression.sql method) + if callable(sql_text): + sql_text = str(val) + value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text)) + elif isinstance(val, exp.Expression): + value_expr = val + else: + column_name = col if isinstance(col, str) else str(col) + if "." in column_name: + column_name = column_name.split(".")[-1] + param_name = self._generate_unique_parameter_name(column_name) + param_name = self.add_parameter(val, name=param_name)[1] + value_expr = exp.Placeholder(this=param_name) + + update_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr)) when_args: dict[str, Any] = {"matched": True, "then": exp.Update(expressions=update_expressions)} @@ -386,15 +429,24 @@ def _add_when_clause(self, when_clause: exp.When) -> None: raise NotImplementedError(msg) def when_not_matched_by_source_then_update( - self, set_values: dict[str, Any], condition: Optional[Union[str, exp.Expression]] = None + self, + set_values: Optional[dict[str, Any]] = None, + condition: Optional[Union[str, exp.Expression]] = None, + **kwargs: Any, ) -> Self: """Define the UPDATE action for rows not matched by source. This is useful for handling rows that exist in the target but not in the source. + Supports: + - when_not_matched_by_source_then_update({"column": value}) + - when_not_matched_by_source_then_update(column=value, other_column=other_value) + - when_not_matched_by_source_then_update({"column": value}, other_column=other_value) + Args: set_values: A dictionary of column names and their new values to set. condition: An optional additional condition for this specific action. + **kwargs: Column-value pairs to update when not matched by source. Raises: SQLBuilderError: If the condition type is unsupported. @@ -402,14 +454,48 @@ def when_not_matched_by_source_then_update( Returns: The current builder instance for method chaining. """ + # Combine set_values dict and kwargs + all_values = dict(set_values or {}, **kwargs) + + if not all_values: + msg = "No update values provided. Use set_values dict or kwargs." + raise SQLBuilderError(msg) + update_expressions: list[exp.EQ] = [] - for col, val in set_values.items(): - column_name = col if isinstance(col, str) else str(col) - if "." in column_name: - column_name = column_name.split(".")[-1] - param_name = self._generate_unique_parameter_name(column_name) - param_name = self.add_parameter(val, name=param_name)[1] - update_expressions.append(exp.EQ(this=exp.column(col), expression=exp.var(param_name))) + for col, val in all_values.items(): + if hasattr(val, "expression") and hasattr(val, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(val, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + value_expr = expression + else: + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(val, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + # Check if sql_text is callable (like Expression.sql method) + if callable(sql_text): + sql_text = str(val) + value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text)) + elif isinstance(val, exp.Expression): + value_expr = val + else: + column_name = col if isinstance(col, str) else str(col) + if "." in column_name: + column_name = column_name.split(".")[-1] + param_name = self._generate_unique_parameter_name(column_name) + param_name = self.add_parameter(val, name=param_name)[1] + value_expr = exp.Placeholder(this=param_name) + + update_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr)) when_args: dict[str, Any] = { "matched": False, diff --git a/sqlspec/builder/mixins/_select_operations.py b/sqlspec/builder/mixins/_select_operations.py index 7afda22c5..0ae1ea764 100644 --- a/sqlspec/builder/mixins/_select_operations.py +++ b/sqlspec/builder/mixins/_select_operations.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from sqlspec.builder._column import Column, FunctionColumn + from sqlspec.core.statement import SQL from sqlspec.protocols import SelectBuilderProtocol, SQLBuilderProtocol __all__ = ("CaseBuilder", "SelectClauseMixin") @@ -27,7 +28,7 @@ class SelectClauseMixin: # Type annotation for PyRight - this will be provided by the base class _expression: Optional[exp.Expression] - def select(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn"]) -> Self: + def select(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn", "SQL"]) -> Self: """Add columns to SELECT clause. Raises: @@ -43,10 +44,10 @@ def select(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn" msg = "Cannot add select columns to a non-SELECT expression." raise SQLBuilderError(msg) for column in columns: - builder._expression = builder._expression.select(parse_column_expression(column), copy=False) + builder._expression = builder._expression.select(parse_column_expression(column, builder), copy=False) return cast("Self", builder) - def distinct(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn"]) -> Self: + def distinct(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn", "SQL"]) -> Self: """Add DISTINCT clause to SELECT. Args: @@ -67,7 +68,7 @@ def distinct(self, *columns: Union[str, exp.Expression, "Column", "FunctionColum if not columns: builder._expression.set("distinct", exp.Distinct()) else: - distinct_columns = [parse_column_expression(column) for column in columns] + distinct_columns = [parse_column_expression(column, builder) for column in columns] builder._expression.set("distinct", exp.Distinct(expressions=distinct_columns)) return cast("Self", builder) diff --git a/sqlspec/builder/mixins/_update_operations.py b/sqlspec/builder/mixins/_update_operations.py index daff5505c..8fe8e47d1 100644 --- a/sqlspec/builder/mixins/_update_operations.py +++ b/sqlspec/builder/mixins/_update_operations.py @@ -1,7 +1,7 @@ """Update operation mixins for SQL builders.""" from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from mypy_extensions import trait from sqlglot import exp @@ -61,6 +61,52 @@ def _generate_unique_parameter_name(self, base_name: str) -> str: msg = "Method must be provided by QueryBuilder subclass" raise NotImplementedError(msg) + def _process_update_value(self, val: Any, col: Any) -> exp.Expression: + """Process a value for UPDATE assignment, handling SQL objects and parameters. + + Args: + val: The value to process + col: The column name for parameter naming + + Returns: + The processed expression for the value + """ + if isinstance(val, exp.Expression): + return val + if has_query_builder_parameters(val): + subquery = val.build() + sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery) + value_expr = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(self, "dialect", None))) + if has_query_builder_parameters(val): + for p_name, p_value in val.parameters.items(): + self.add_parameter(p_value, name=p_name) + return value_expr + if hasattr(val, "expression") and hasattr(val, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(val, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + return cast("exp.Expression", expression) + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(val, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(val, "parameters"): + sql_parameters = getattr(val, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + self.add_parameter(param_value, name=param_name) + parsed_expr = exp.maybe_parse(sql_text) + return parsed_expr if parsed_expr is not None else exp.convert(str(sql_text)) + column_name = col if isinstance(col, str) else str(col) + if "." in column_name: + column_name = column_name.split(".")[-1] + param_name = self._generate_unique_parameter_name(column_name) + param_name = self.add_parameter(val, name=param_name)[1] + return exp.Placeholder(this=param_name) + def set(self, *args: Any, **kwargs: Any) -> Self: """Set columns and values for the UPDATE statement. @@ -80,7 +126,6 @@ def set(self, *args: Any, **kwargs: Any) -> Self: Returns: The current builder instance for method chaining. """ - if self._expression is None: self._expression = exp.Update() if not isinstance(self._expression, exp.Update): @@ -90,42 +135,12 @@ def set(self, *args: Any, **kwargs: Any) -> Self: if len(args) == MIN_SET_ARGS and not kwargs: col, val = args col_expr = col if isinstance(col, exp.Column) else exp.column(col) - if isinstance(val, exp.Expression): - value_expr = val - elif has_query_builder_parameters(val): - subquery = val.build() - sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery) - value_expr = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(self, "dialect", None))) - if has_query_builder_parameters(val): - for p_name, p_value in val.parameters.items(): - self.add_parameter(p_value, name=p_name) - else: - column_name = col if isinstance(col, str) else str(col) - if "." in column_name: - column_name = column_name.split(".")[-1] - param_name = self._generate_unique_parameter_name(column_name) - param_name = self.add_parameter(val, name=param_name)[1] - value_expr = exp.Placeholder(this=param_name) + value_expr = self._process_update_value(val, col) assignments.append(exp.EQ(this=col_expr, expression=value_expr)) elif (len(args) == 1 and isinstance(args[0], Mapping)) or kwargs: all_values = dict(args[0] if args else {}, **kwargs) for col, val in all_values.items(): - if isinstance(val, exp.Expression): - value_expr = val - elif has_query_builder_parameters(val): - subquery = val.build() - sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery) - value_expr = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(self, "dialect", None))) - if has_query_builder_parameters(val): - for p_name, p_value in val.parameters.items(): - self.add_parameter(p_value, name=p_name) - else: - column_name = col if isinstance(col, str) else str(col) - if "." in column_name: - column_name = column_name.split(".")[-1] - param_name = self._generate_unique_parameter_name(column_name) - param_name = self.add_parameter(val, name=param_name)[1] - value_expr = exp.Placeholder(this=param_name) + value_expr = self._process_update_value(val, col) assignments.append(exp.EQ(this=exp.column(col), expression=value_expr)) else: msg = "Invalid arguments for set(): use (column, value), mapping, or kwargs." diff --git a/sqlspec/builder/mixins/_where_clause.py b/sqlspec/builder/mixins/_where_clause.py index 159ae935c..36aeb985c 100644 --- a/sqlspec/builder/mixins/_where_clause.py +++ b/sqlspec/builder/mixins/_where_clause.py @@ -3,6 +3,9 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast +if TYPE_CHECKING: + from sqlspec.core.statement import SQL + from mypy_extensions import trait from sqlglot import exp from typing_extensions import Self @@ -208,7 +211,9 @@ def _process_tuple_condition(self, condition: tuple) -> exp.Expression: def where( self, - condition: Union[str, exp.Expression, exp.Condition, tuple[str, Any], tuple[str, str, Any], "ColumnExpression"], + condition: Union[ + str, exp.Expression, exp.Condition, tuple[str, Any], tuple[str, str, Any], "ColumnExpression", "SQL" + ], value: Optional[Any] = None, operator: Optional[str] = None, ) -> Self: @@ -267,6 +272,25 @@ def where( where_expr = builder._parameterize_expression(raw_expr) else: where_expr = parse_condition_expression(str(condition)) + elif hasattr(condition, "expression") and hasattr(condition, "sql"): + # Handle SQL objects (from sql.raw with parameters) + expression = getattr(condition, "expression", None) + if expression is not None and isinstance(expression, exp.Expression): + # Merge parameters from SQL object into builder + if hasattr(condition, "parameters") and hasattr(builder, "add_parameter"): + sql_parameters = getattr(condition, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + builder.add_parameter(param_value, name=param_name) + where_expr = expression + else: + # If expression is None, fall back to parsing the raw SQL + sql_text = getattr(condition, "sql", "") + # Merge parameters even when parsing raw SQL + if hasattr(condition, "parameters") and hasattr(builder, "add_parameter"): + sql_parameters = getattr(condition, "parameters", {}) + for param_name, param_value in sql_parameters.items(): + builder.add_parameter(param_value, name=param_name) + where_expr = parse_condition_expression(sql_text) else: msg = f"Unsupported condition type: {type(condition).__name__}" raise SQLBuilderError(msg) @@ -596,7 +620,6 @@ class HavingClauseMixin: __slots__ = () - # Type annotation for PyRight - this will be provided by the base class _expression: Optional[exp.Expression] def having(self, condition: Union[str, exp.Expression]) -> Self: diff --git a/tests/unit/test_builder/test_insert_builder.py b/tests/unit/test_builder/test_insert_builder.py new file mode 100644 index 000000000..8891acc36 --- /dev/null +++ b/tests/unit/test_builder/test_insert_builder.py @@ -0,0 +1,321 @@ +"""Unit tests for INSERT builder functionality including ON CONFLICT operations.""" + +import pytest + +from sqlspec import sql +from sqlspec.exceptions import SQLBuilderError + + +def test_insert_basic_functionality() -> None: + """Test basic INSERT builder functionality.""" + query = sql.insert("users").columns("name", "email").values("John", "john@test.com") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert '"users"' in stmt.sql or "users" in stmt.sql + assert "name" in stmt.parameters + assert "email" in stmt.parameters + assert stmt.parameters["name"] == "John" + assert stmt.parameters["email"] == "john@test.com" + + +def test_insert_with_table_in_constructor() -> None: + """Test INSERT with table specified in constructor.""" + query = sql.insert("products").values(name="Widget", price=29.99) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "products" in stmt.sql + assert "name" in stmt.parameters + assert "price" in stmt.parameters + + +def test_insert_values_from_dict() -> None: + """Test INSERT using values_from_dict method.""" + data = {"id": 1, "name": "John", "status": "active"} + query = sql.insert("users").values_from_dict(data) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert len(stmt.parameters) == 3 + assert stmt.parameters["id"] == 1 + assert stmt.parameters["name"] == "John" + assert stmt.parameters["status"] == "active" + + +def test_insert_values_from_dicts_multiple_rows() -> None: + """Test INSERT using values_from_dicts for multiple rows.""" + data = [{"id": 1, "name": "John"}, {"id": 2, "name": "Jane"}, {"id": 3, "name": "Bob"}] + query = sql.insert("users").values_from_dicts(data) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + # Should have parameters for all rows + assert "id" in stmt.parameters + assert "name" in stmt.parameters + assert "id_1" in stmt.parameters + assert "name_1" in stmt.parameters + assert "id_2" in stmt.parameters + assert "name_2" in stmt.parameters + + +def test_insert_with_kwargs() -> None: + """Test INSERT using kwargs in values method.""" + query = sql.insert("products").values(name="Widget", price=29.99, in_stock=True) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "name" in stmt.parameters + assert "price" in stmt.parameters + assert "in_stock" in stmt.parameters + assert stmt.parameters["name"] == "Widget" + assert stmt.parameters["price"] == 29.99 + assert stmt.parameters["in_stock"] is True + + +def test_insert_mixed_args_kwargs_error() -> None: + """Test that mixing positional and keyword arguments raises error.""" + with pytest.raises(SQLBuilderError, match="Cannot mix positional values with keyword values"): + sql.insert("users").values("John", email="john@test.com") + + +def test_insert_multiple_values_calls() -> None: + """Test multiple calls to values() method for multi-row insert.""" + query = sql.insert("users").columns("name", "email").values("John", "john@test.com").values("Jane", "jane@test.com") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + # Should have parameters for both rows + assert "name" in stmt.parameters + assert "email" in stmt.parameters + assert "name_1" in stmt.parameters + assert "email_1" in stmt.parameters + + +def test_insert_with_returning() -> None: + """Test INSERT with RETURNING clause.""" + query = sql.insert("users").values(name="John").returning("id", "created_at") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "RETURNING" in stmt.sql + assert "id" in stmt.sql + assert "created_at" in stmt.sql + + +def test_insert_without_table_error() -> None: + """Test that values() without table raises error.""" + with pytest.raises(SQLBuilderError, match="The target table must be set"): + sql.insert().values(name="John") + + +def test_insert_values_columns_mismatch_error() -> None: + """Test that mismatched columns and values raises error.""" + with pytest.raises(SQLBuilderError, match="Number of values"): + sql.insert("users").columns("name", "email").values("John") # Missing email value + + +def test_insert_columns_and_values_consistency() -> None: + """Test that columns and values are consistent.""" + query = sql.insert("users").columns("name", "email", "age").values("John", "john@test.com", 25) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert len(stmt.parameters) == 3 + assert stmt.parameters["name"] == "John" + assert stmt.parameters["email"] == "john@test.com" + assert stmt.parameters["age"] == 25 + + +def test_insert_inconsistent_dict_keys_error() -> None: + """Test that inconsistent dictionary keys in values_from_dicts raises error.""" + data = [ + {"id": 1, "name": "John"}, + {"id": 2, "email": "jane@test.com"}, # Missing name, has email instead + ] + with pytest.raises(SQLBuilderError, match="do not match expected keys"): + sql.insert("users").values_from_dicts(data).build() + + +def test_insert_with_sql_raw_expressions() -> None: + """Test INSERT with sql.raw expressions.""" + query = sql.insert("logs").values(message="Test message", created_at=sql.raw("NOW()"), uuid=sql.raw("UUID()")) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "NOW()" in stmt.sql + assert "UUID()" in stmt.sql + assert "message" in stmt.parameters + assert stmt.parameters["message"] == "Test message" + + +def test_insert_with_sql_raw_parameters() -> None: + """Test INSERT with sql.raw that has parameters.""" + query = sql.insert("users").values( + name="John", computed_field=sql.raw("COALESCE(:fallback, 'default')", fallback="custom") + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "COALESCE" in stmt.sql + assert "name" in stmt.parameters + assert "fallback" in stmt.parameters + assert stmt.parameters["name"] == "John" + assert stmt.parameters["fallback"] == "custom" + + +# ON CONFLICT functionality tests +def test_on_conflict_do_nothing_basic() -> None: + """Test basic ON CONFLICT DO NOTHING.""" + query = sql.insert("users").values(id=1, name="John").on_conflict("id").do_nothing() + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + assert "id" in stmt.parameters + + +def test_on_conflict_do_update_basic() -> None: + """Test basic ON CONFLICT DO UPDATE.""" + query = sql.insert("users").values(id=1, name="John").on_conflict("id").do_update(name="Updated") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "SET" in stmt.sql + assert "name_1" in stmt.parameters + assert stmt.parameters["name_1"] == "Updated" + + +def test_on_conflict_multiple_columns() -> None: + """Test ON CONFLICT with multiple columns.""" + query = ( + sql.insert("users") + .values(email="john@test.com", username="john", name="John") + .on_conflict("email", "username") + .do_nothing() + ) + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "email" in stmt.sql + assert "username" in stmt.sql + + +def test_on_conflict_no_columns() -> None: + """Test ON CONFLICT without specific columns.""" + query = sql.insert("users").values(id=1, name="John").on_conflict().do_nothing() + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + # Should not specify columns + assert "ON CONFLICT(" not in stmt.sql + + +def test_on_conflict_do_update_with_sql_raw() -> None: + """Test ON CONFLICT DO UPDATE with sql.raw expressions.""" + query = ( + sql.insert("users") + .values(id=1, name="John") + .on_conflict("id") + .do_update(updated_at=sql.raw("NOW()"), name="Updated") + ) + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert "name_1" in stmt.parameters + + +def test_on_conflict_convenience_method() -> None: + """Test on_conflict_do_nothing convenience method.""" + query = sql.insert("users").values(id=1, name="John").on_conflict_do_nothing("id") + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + + +def test_legacy_on_duplicate_key_update() -> None: + """Test legacy on_duplicate_key_update method.""" + query = ( + sql.insert("users") + .values(id=1, name="John") + .on_duplicate_key_update(name="Updated", updated_at=sql.raw("NOW()")) + ) + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + + +def test_on_conflict_chaining() -> None: + """Test ON CONFLICT method chaining.""" + query = ( + sql.insert("users") + .values(id=1, name="John") + .on_conflict("id") + .do_update(name="Updated") + .returning("id", "name") + ) + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "RETURNING" in stmt.sql + + +def test_on_conflict_type_safety() -> None: + """Test ON CONFLICT method return types for chaining.""" + insert_builder = sql.insert("users").values(id=1, name="John") + + # on_conflict should return ConflictBuilder + conflict_builder = insert_builder.on_conflict("id") + assert hasattr(conflict_builder, "do_nothing") + assert hasattr(conflict_builder, "do_update") + + # do_nothing should return Insert for further chaining + final_builder = conflict_builder.do_nothing() + assert hasattr(final_builder, "returning") + assert hasattr(final_builder, "build") + + +def test_on_conflict_empty_do_update() -> None: + """Test ON CONFLICT DO UPDATE with no arguments.""" + query = sql.insert("users").values(id=1, name="John").on_conflict("id").do_update() + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + + +def test_on_conflict_parameter_merging() -> None: + """Test that ON CONFLICT properly merges parameters from SQL objects.""" + query = ( + sql.insert("users") + .values(id=1, name="John") + .on_conflict("id") + .do_update(name=sql.raw("COALESCE(:new_name, name)", new_name="Updated"), updated_at=sql.raw("NOW()")) + ) + stmt = query.build() + + assert "new_name" in stmt.parameters + assert stmt.parameters["new_name"] == "Updated" + assert "NOW()" in stmt.sql + + +def test_on_conflict_with_values_from_dict() -> None: + """Test ON CONFLICT with values_from_dict.""" + data = {"id": 1, "name": "John", "email": "john@test.com"} + query = sql.insert("users").values_from_dict(data).on_conflict("id").do_update(name="Updated") + stmt = query.build() + + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "name_1" in stmt.parameters + assert stmt.parameters["name_1"] == "Updated" diff --git a/tests/unit/test_sql_factory.py b/tests/unit/test_sql_factory.py index 6eb2c514a..99178f273 100644 --- a/tests/unit/test_sql_factory.py +++ b/tests/unit/test_sql_factory.py @@ -1,5 +1,7 @@ """Unit tests for SQL factory functionality including parameter binding fixes and new features.""" +import math + import pytest from sqlglot import exp @@ -267,9 +269,9 @@ def test_raw_parameter_overwrite_behavior() -> None: stmt = sql.raw("field1 = :value AND field2 = :value", value="test") assert isinstance(stmt, SQL) - assert stmt.parameters["value"] == "test" assert stmt.sql.count(":value") == 2 assert len(stmt.parameters) == 1 + assert stmt.parameters["value"] == "test" def test_select_method() -> None: @@ -496,7 +498,7 @@ def test_parameter_names_use_column_names() -> None: def test_parameter_values_preserved_correctly() -> None: """Test that parameter values are preserved exactly.""" - test_values = [("string_val", "test"), ("int_val", 42), ("float_val", 3.14159), ("bool_val", True)] + test_values = [("string_val", "test"), ("int_val", 42), ("float_val", math.pi), ("bool_val", True)] query = sql.select("*").from_("test") for column_name, value in test_values: @@ -862,3 +864,660 @@ def test_backward_compatibility_preserved() -> None: assert isinstance(sql.users, Column) assert isinstance(sql.posts, Column) + + +# Tests for type annotation fixes and SQL object compatibility +def test_case_as_method_type_annotation_fix() -> None: + """Test that sql.case().as_() method returns proper type without 'partially unknown' errors.""" + # This test verifies the fix for the original user issue + case_expr = sql.case().when("status = 'active'", "Active").else_("Inactive").as_("status_display") + + # Should be able to use in select without type errors + query = sql.select("id", "name", case_expr).from_("users") + stmt = query.build() + + assert "CASE" in stmt.sql + assert "status_display" in stmt.sql + assert "Active" in stmt.sql + assert "Inactive" in stmt.sql + + # Verify it's properly aliased + assert " AS " in stmt.sql or "status_display" in stmt.sql + + +def test_window_function_as_method_type_annotation_fix() -> None: + """Test that window function as_() method also has proper type annotations.""" + window_func = sql.row_number_.partition_by("department").order_by("salary").as_("row_num") + + query = sql.select("name", window_func).from_("employees") + stmt = query.build() + + assert "ROW_NUMBER()" in stmt.sql + assert "row_num" in stmt.sql + assert "OVER" in stmt.sql + + +def test_sql_raw_object_in_select_clause() -> None: + """Test that SQL objects from sql.raw work in SELECT clauses with parameter merging.""" + raw_expr = sql.raw("COALESCE(name, :default_name)", default_name="Unknown") + + query = sql.select("id", raw_expr).from_("users") + stmt = query.build() + + assert "COALESCE" in stmt.sql + assert "default_name" in stmt.parameters + assert stmt.parameters["default_name"] == "Unknown" + assert ":default_name" in stmt.sql + + +def test_sql_raw_object_in_join_conditions() -> None: + """Test that SQL objects from sql.raw work in JOIN conditions with parameter merging.""" + join_condition = sql.raw("users.id = posts.user_id AND posts.status = :status", status="published") + + query = sql.select("users.name", "posts.title").from_("users").left_join("posts", join_condition) + stmt = query.build() + + assert "LEFT JOIN" in stmt.sql + assert "status" in stmt.parameters + assert stmt.parameters["status"] == "published" + assert ":status" in stmt.sql + + +def test_sql_raw_object_in_where_clauses() -> None: + """Test that SQL objects from sql.raw work in WHERE clauses with parameter merging.""" + where_condition = sql.raw("LENGTH(name) > :min_length", min_length=5) + + query = sql.select("*").from_("users").where(where_condition) + stmt = query.build() + + assert "LENGTH" in stmt.sql + assert "min_length" in stmt.parameters + assert stmt.parameters["min_length"] == 5 + assert ":min_length" in stmt.sql + + +def test_sql_raw_object_in_distinct_clause() -> None: + """Test that SQL objects work in DISTINCT clauses with parameter merging.""" + raw_expr = sql.raw("UPPER(category)") + + query = sql.select("*").from_("products").distinct(raw_expr) + stmt = query.build() + + assert "DISTINCT" in stmt.sql + assert "UPPER" in stmt.sql + + +def test_multiple_sql_raw_objects_parameter_merging() -> None: + """Test that multiple SQL objects properly merge their parameters.""" + select_expr = sql.raw("COALESCE(name, :default_name)", default_name="Unknown") + join_condition = sql.raw("users.id = posts.user_id AND posts.status = :status", status="published") + where_condition = sql.raw("users.created_at > :min_date", min_date="2023-01-01") + + query = sql.select("id", select_expr).from_("users").left_join("posts", join_condition).where(where_condition) + stmt = query.build() + + # All parameters should be merged + assert len(stmt.parameters) == 3 + assert stmt.parameters["default_name"] == "Unknown" + assert stmt.parameters["status"] == "published" + assert stmt.parameters["min_date"] == "2023-01-01" + + # All placeholders should be in SQL + assert ":default_name" in stmt.sql + assert ":status" in stmt.sql + assert ":min_date" in stmt.sql + + +def test_sql_raw_without_parameters_still_works() -> None: + """Test that SQL objects without parameters still work correctly.""" + raw_expr = sql.raw("NOW()") + + query = sql.select("id", raw_expr).from_("logs") + stmt = query.build() + + assert "NOW()" in stmt.sql + assert len(stmt.parameters) == 0 + + +def test_mixed_sql_objects_and_regular_parameters() -> None: + """Test mixing SQL objects with regular builder parameters.""" + raw_expr = sql.raw("UPPER(name)") + + query = ( + sql.select("id", raw_expr) + .from_("users") + .where_eq("status", "active") # Regular parameter + .where(sql.raw("created_at > :min_date", min_date="2023-01-01")) # SQL object parameter + ) + stmt = query.build() + + # Should have both types of parameters + assert "status" in stmt.parameters + assert "min_date" in stmt.parameters + assert stmt.parameters["status"] == "active" + assert stmt.parameters["min_date"] == "2023-01-01" + + assert "UPPER" in stmt.sql + assert ":status" in stmt.sql + assert ":min_date" in stmt.sql + + +def test_sql_raw_parameter_name_conflicts_handled() -> None: + """Test that parameter name conflicts are detected when merging SQL objects.""" + # Create two SQL objects with different parameter names (should work fine) + raw_expr1 = sql.raw("COALESCE(name, :value)", value="default1") + raw_expr2 = sql.raw("COALESCE(email, :other_value)", other_value="default2") + + query = sql.select("id", raw_expr1, raw_expr2).from_("users") + stmt = query.build() + + assert "value" in stmt.parameters + assert "other_value" in stmt.parameters + assert stmt.parameters["value"] == "default1" + assert stmt.parameters["other_value"] == "default2" + + # Test that actual conflicts are detected + raw_conflict1 = sql.raw("COALESCE(name, :conflict)", conflict="first") + raw_conflict2 = sql.raw("COALESCE(email, :conflict)", conflict="second") + + with pytest.raises(SQLBuilderError, match="Parameter name 'conflict' already exists"): + sql.select("id", raw_conflict1, raw_conflict2).from_("users").build() + + +def test_original_user_case_example_regression_test() -> None: + """Regression test for the exact user example that was failing.""" + # This was the original failing example + case_expr = sql.case().when("password IS NOT NULL", True).else_(False).as_("has_password") + + query = sql.select("id", "name", case_expr).from_("users") + stmt = query.build() + + # Should work without type annotation errors + assert "CASE" in stmt.sql + assert "has_password" in stmt.sql + assert "password" in stmt.sql and ("NULL" in stmt.sql or "IS" in stmt.sql) + + # Should also work in UPDATE operations + update_query = sql.update("users").set({"last_check": sql.raw("NOW()")}).where(case_expr) + update_stmt = update_query.build() + + assert "UPDATE" in update_stmt.sql + assert "CASE" in update_stmt.sql + + +def test_type_compatibility_across_all_operations() -> None: + """Test that SQL objects work across all major SQL operations.""" + # Test in various contexts to ensure type compatibility + raw_condition = sql.raw("LENGTH(name) > :min_len", min_len=3) + raw_value = sql.raw("UPPER(:new_name)", new_name="test") + raw_select = sql.raw("COUNT(*) as total") + + # SELECT with SQL objects + select_query = sql.select("id", raw_select).from_("users").where(raw_condition) + select_stmt = select_query.build() + assert "COUNT(*)" in select_stmt.sql + assert "min_len" in select_stmt.parameters + + # UPDATE with SQL objects using kwargs + update_query = sql.update("users").set(name=raw_value, status="updated").where(raw_condition) + update_stmt = update_query.build() + assert "UPDATE" in update_stmt.sql + assert "min_len" in update_stmt.parameters + assert "new_name" in update_stmt.parameters + assert "status" in update_stmt.parameters + + # DELETE with SQL objects + delete_query = sql.delete().from_("users").where(raw_condition) + delete_stmt = delete_query.build() + assert "DELETE" in delete_stmt.sql + assert "min_len" in delete_stmt.parameters + + +def test_update_set_method_with_sql_objects() -> None: + """Test that UPDATE.set() method properly handles SQL objects with kwargs.""" + raw_timestamp = sql.raw("NOW()") + raw_computed = sql.raw("UPPER(:value)", value="test") + + # Test using kwargs with SQL objects + query = ( + sql.update("users").set(name="John", last_updated=raw_timestamp, computed_field=raw_computed).where_eq("id", 1) + ) + + stmt = query.build() + + assert "UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert "UPPER" in stmt.sql + assert "name" in stmt.parameters + assert "value" in stmt.parameters + assert "id" in stmt.parameters + assert stmt.parameters["name"] == "John" + assert stmt.parameters["value"] == "test" + assert stmt.parameters["id"] == 1 + + +def test_update_set_method_backward_compatibility() -> None: + """Test that UPDATE.set() method maintains backward compatibility with dict.""" + raw_timestamp = sql.raw("NOW()") + + # Test using dict (original API) + query1 = sql.update("users").set({"name": "John", "updated_at": raw_timestamp}) + stmt1 = query1.build() + + assert "UPDATE" in stmt1.sql + assert "NOW()" in stmt1.sql + assert "name" in stmt1.parameters + assert stmt1.parameters["name"] == "John" + + # Test using positional args (column, value) + query2 = sql.update("users").set("status", "active") + stmt2 = query2.build() + + assert "UPDATE" in stmt2.sql + assert "status" in stmt2.parameters + assert stmt2.parameters["status"] == "active" + + +# Tests for ON CONFLICT functionality +def test_on_conflict_do_nothing_basic() -> None: + """Test basic ON CONFLICT DO NOTHING functionality.""" + query = sql.insert("users").columns("id", "name").values(1, "John").on_conflict("id").do_nothing() + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + assert '"id"' in stmt.sql or "id" in stmt.sql + assert "id" in stmt.parameters + assert "name" in stmt.parameters + assert stmt.parameters["id"] == 1 + assert stmt.parameters["name"] == "John" + + +def test_on_conflict_do_nothing_multiple_columns() -> None: + """Test ON CONFLICT DO NOTHING with multiple conflict columns.""" + query = ( + sql.insert("users") + .columns("email", "username", "name") + .values("john@test.com", "john", "John") + .on_conflict("email", "username") + .do_nothing() + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + assert "email" in stmt.sql and "username" in stmt.sql + assert "email" in stmt.parameters + assert "username" in stmt.parameters + assert "name" in stmt.parameters + + +def test_on_conflict_do_nothing_no_columns() -> None: + """Test ON CONFLICT DO NOTHING without specific columns (catches all conflicts).""" + query = sql.insert("users").columns("id", "name").values(1, "John").on_conflict().do_nothing() + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + # Should not have specific columns in conflict clause + assert "ON CONFLICT(" not in stmt.sql + + +def test_on_conflict_do_update_basic() -> None: + """Test basic ON CONFLICT DO UPDATE functionality.""" + query = sql.insert("users").columns("id", "name").values(1, "John").on_conflict("id").do_update(name="Updated John") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "SET" in stmt.sql + assert "id" in stmt.parameters + assert "name" in stmt.parameters + assert "name_1" in stmt.parameters # The update parameter + assert stmt.parameters["id"] == 1 + assert stmt.parameters["name"] == "John" + assert stmt.parameters["name_1"] == "Updated John" + + +def test_on_conflict_do_update_multiple_values() -> None: + """Test ON CONFLICT DO UPDATE with multiple update values.""" + query = ( + sql.insert("users") + .columns("id", "name", "email") + .values(1, "John", "john@test.com") + .on_conflict("id") + .do_update(name="Updated John", email="updated@test.com") + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "SET" in stmt.sql + assert "name_1" in stmt.parameters + assert "email_1" in stmt.parameters + assert stmt.parameters["name_1"] == "Updated John" + assert stmt.parameters["email_1"] == "updated@test.com" + + +def test_on_conflict_do_update_with_sql_raw() -> None: + """Test ON CONFLICT DO UPDATE with sql.raw expressions.""" + query = ( + sql.insert("users") + .columns("id", "name") + .values(1, "John") + .on_conflict("id") + .do_update(updated_at=sql.raw("NOW()"), name="Updated") + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "SET" in stmt.sql + assert "NOW()" in stmt.sql + assert "name_1" in stmt.parameters + assert stmt.parameters["name_1"] == "Updated" + + +def test_on_conflict_do_update_with_sql_raw_parameters() -> None: + """Test ON CONFLICT DO UPDATE with sql.raw that has parameters.""" + query = ( + sql.insert("users") + .columns("id", "name") + .values(1, "John") + .on_conflict("id") + .do_update( + updated_at=sql.raw("NOW()"), status=sql.raw("COALESCE(:new_status, 'active')", new_status="verified") + ) + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert "COALESCE" in stmt.sql + assert "new_status" in stmt.parameters + assert stmt.parameters["new_status"] == "verified" + + +def test_on_conflict_convenience_method() -> None: + """Test the convenience method on_conflict_do_nothing.""" + query = sql.insert("users").columns("id", "name").values(1, "John").on_conflict_do_nothing("id") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + assert "id" in stmt.parameters + assert stmt.parameters["id"] == 1 + + +def test_legacy_on_duplicate_key_update_method() -> None: + """Test that the legacy on_duplicate_key_update method uses the new ON CONFLICT API.""" + query = ( + sql.insert("users") + .columns("id", "name") + .values(1, "John") + .on_duplicate_key_update(name="Updated", updated_at=sql.raw("NOW()")) + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "SET" in stmt.sql + assert "NOW()" in stmt.sql + assert "name_1" in stmt.parameters + assert stmt.parameters["name_1"] == "Updated" + + +def test_on_conflict_with_insert_from_dict() -> None: + """Test ON CONFLICT with insert using from_dict methods.""" + data = {"id": 1, "name": "John", "email": "john@test.com"} + query = sql.insert("users").values_from_dict(data).on_conflict("id").do_update(name="Updated John") + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "id" in stmt.parameters + assert "name" in stmt.parameters + assert "email" in stmt.parameters + assert "name_1" in stmt.parameters # The update parameter + assert stmt.parameters["name_1"] == "Updated John" + + +def test_on_conflict_with_multiple_rows() -> None: + """Test ON CONFLICT with multiple value rows.""" + query = sql.insert("users").columns("id", "name").values(1, "John").values(2, "Jane").on_conflict("id").do_nothing() + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO NOTHING" in stmt.sql + # Check that both rows are included + assert "id" in stmt.parameters + assert "name" in stmt.parameters + assert "id_1" in stmt.parameters + assert "name_1" in stmt.parameters + + +def test_on_conflict_chaining_with_returning() -> None: + """Test ON CONFLICT chaining with RETURNING clause.""" + query = ( + sql.insert("users") + .columns("id", "name") + .values(1, "John") + .on_conflict("id") + .do_update(name="Updated John") + .returning("id", "name", "updated_at") + ) + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + assert "RETURNING" in stmt.sql + assert "id" in stmt.sql and "name" in stmt.sql and "updated_at" in stmt.sql + + +def test_on_conflict_empty_do_update() -> None: + """Test ON CONFLICT DO UPDATE with no arguments (should work but do nothing).""" + query = sql.insert("users").columns("id", "name").values(1, "John").on_conflict("id").do_update() + stmt = query.build() + + assert "INSERT INTO" in stmt.sql + assert "ON CONFLICT" in stmt.sql + assert "DO UPDATE" in stmt.sql + + +def test_on_conflict_sql_generation_postgres_style() -> None: + """Test that ON CONFLICT generates PostgreSQL-style syntax that SQLGlot can transpile.""" + query = sql.insert("users").columns("id", "name").values(1, "John").on_conflict("id").do_update(name="Updated") + stmt = query.build() + + # Should generate proper PostgreSQL ON CONFLICT syntax + assert "ON CONFLICT(" in stmt.sql or "ON CONFLICT (" in stmt.sql + assert "DO UPDATE SET" in stmt.sql + assert '"id"' in stmt.sql or "id" in stmt.sql + + +def test_on_conflict_type_safety() -> None: + """Test that ON CONFLICT methods return proper types for method chaining.""" + # This test ensures the ConflictBuilder properly returns Insert builder + query_builder = sql.insert("users").columns("id", "name").values(1, "John") + + # on_conflict should return ConflictBuilder + conflict_builder = query_builder.on_conflict("id") + assert hasattr(conflict_builder, "do_nothing") + assert hasattr(conflict_builder, "do_update") + + # do_nothing should return Insert builder for further chaining + final_builder = conflict_builder.do_nothing() + assert hasattr(final_builder, "returning") + assert hasattr(final_builder, "build") + + # Should be able to continue chaining + final_query = final_builder.returning("id") + stmt = final_query.build() + assert "RETURNING" in stmt.sql + + +# Tests for MERGE kwargs functionality +def test_merge_when_matched_then_update_with_kwargs() -> None: + """Test MERGE when_matched_then_update with kwargs support.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_matched_then_update(name="Updated John", email="updated@test.com") + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN MATCHED THEN UPDATE" in stmt.sql + assert "name" in stmt.parameters + assert "email" in stmt.parameters + assert stmt.parameters["name"] == "Updated John" + assert stmt.parameters["email"] == "updated@test.com" + + +def test_merge_when_matched_then_update_mixed_dict_kwargs() -> None: + """Test MERGE when_matched_then_update with mixed dict and kwargs.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_matched_then_update({"name": "Dict Name"}, email="Kwargs Email", status="active") + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN MATCHED THEN UPDATE" in stmt.sql + assert "name" in stmt.parameters + assert "email" in stmt.parameters + assert "status" in stmt.parameters + assert stmt.parameters["name"] == "Dict Name" + assert stmt.parameters["email"] == "Kwargs Email" + assert stmt.parameters["status"] == "active" + + +def test_merge_when_matched_then_update_with_sql_raw() -> None: + """Test MERGE when_matched_then_update with sql.raw expressions.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_matched_then_update( + name="Updated John", + updated_at=sql.raw("NOW()"), + status=sql.raw("COALESCE(:new_status, 'active')", new_status="verified"), + ) + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN MATCHED THEN UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert "COALESCE" in stmt.sql + assert "name" in stmt.parameters + assert "new_status" in stmt.parameters + assert stmt.parameters["name"] == "Updated John" + assert stmt.parameters["new_status"] == "verified" + + +def test_merge_when_not_matched_by_source_then_update_with_kwargs() -> None: + """Test MERGE when_not_matched_by_source_then_update with kwargs support.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_not_matched_by_source_then_update(status="inactive", last_seen=sql.raw("NOW()")) + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN NOT MATCHED BY SOURCE THEN UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert "status" in stmt.parameters + assert stmt.parameters["status"] == "inactive" + + +def test_merge_when_not_matched_by_source_then_update_mixed() -> None: + """Test MERGE when_not_matched_by_source_then_update with mixed dict and kwargs.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_not_matched_by_source_then_update( + {"status": "Dict Status"}, last_seen=sql.raw("NOW()"), notes="Kwargs Notes" + ) + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN NOT MATCHED BY SOURCE THEN UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert "status" in stmt.parameters + assert "notes" in stmt.parameters + assert stmt.parameters["status"] == "Dict Status" + assert stmt.parameters["notes"] == "Kwargs Notes" + + +def test_merge_empty_update_values_error() -> None: + """Test that MERGE update methods raise error when no values provided.""" + merge_builder = sql.merge("users").using("new_users").on("users.id = new_users.id") + + with pytest.raises(SQLBuilderError, match="No update values provided"): + merge_builder.when_matched_then_update() + + with pytest.raises(SQLBuilderError, match="No update values provided"): + merge_builder.when_not_matched_by_source_then_update() + + +def test_merge_backward_compatibility() -> None: + """Test that MERGE methods maintain backward compatibility with dict-only usage.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_matched_then_update({"name": "Updated", "email": "updated@test.com"}) + .when_not_matched_by_source_then_update({"status": "inactive"}) + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN MATCHED THEN UPDATE" in stmt.sql + assert "WHEN NOT MATCHED BY SOURCE THEN UPDATE" in stmt.sql + assert "name" in stmt.parameters + assert "email" in stmt.parameters + assert "status" in stmt.parameters + + +def test_merge_comprehensive_example() -> None: + """Test comprehensive MERGE example with all features.""" + query = ( + sql.merge("users") + .using("new_users") + .on("users.id = new_users.id") + .when_matched_then_update(name="new_users.name", email="new_users.email", updated_at=sql.raw("NOW()")) + .when_not_matched_then_insert( + ["id", "name", "email", "created_at"], + ["new_users.id", "new_users.name", "new_users.email", sql.raw("NOW()")], + ) + .when_not_matched_by_source_then_update(status="archived", archived_at=sql.raw("NOW()")) + ) + stmt = query.build() + + assert "MERGE INTO" in stmt.sql + assert "WHEN MATCHED THEN UPDATE" in stmt.sql + assert "WHEN NOT MATCHED THEN INSERT" in stmt.sql + assert "WHEN NOT MATCHED BY SOURCE THEN UPDATE" in stmt.sql + assert "NOW()" in stmt.sql + assert len(stmt.parameters) >= 6 # Should have multiple parameters diff --git a/uv.lock b/uv.lock index 743feb275..5ca054876 100644 --- a/uv.lock +++ b/uv.lock @@ -4957,7 +4957,7 @@ wheels = [ [[package]] name = "sqlspec" -version = "0.16.1" +version = "0.16.2" source = { editable = "." } dependencies = [ { name = "eval-type-backport", marker = "python_full_version < '3.10'" },