Skip to content

Commit 89fdd7b

Browse files
authored
feat: add FOR UPDATE locking support to query builder (#88)
Builder improvements related to row locking: - Add row-level locking support with `FOR UPDATE`, `FOR SHARE`, `SKIP LOCKED`, and `NOWAIT` clauses - Implement support for table-specific locking with `OF` clause
1 parent efc1aa5 commit 89fdd7b

File tree

28 files changed

+1940
-385
lines changed

28 files changed

+1940
-385
lines changed

sqlspec/_sql.py

Lines changed: 12 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
MathExpression,
4141
StringExpression,
4242
)
43+
from sqlspec.builder._parsing_utils import extract_expression, to_expression
4344
from sqlspec.builder.mixins._join_operations import JoinBuilder
4445
from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
4546
from sqlspec.core.statement import SQL
@@ -746,7 +747,7 @@ def count(
746747
if isinstance(column, str) and column == "*":
747748
expr = exp.Count(this=exp.Star(), distinct=distinct)
748749
else:
749-
col_expr = self._extract_expression(column)
750+
col_expr = extract_expression(column)
750751
expr = exp.Count(this=col_expr, distinct=distinct)
751752
return AggregateExpression(expr)
752753

@@ -774,7 +775,7 @@ def sum(
774775
Returns:
775776
SUM expression.
776777
"""
777-
col_expr = SQLFactory._extract_expression(column)
778+
col_expr = extract_expression(column)
778779
return AggregateExpression(exp.Sum(this=col_expr, distinct=distinct))
779780

780781
@staticmethod
@@ -787,7 +788,7 @@ def avg(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> Aggr
787788
Returns:
788789
AVG expression.
789790
"""
790-
col_expr = SQLFactory._extract_expression(column)
791+
col_expr = extract_expression(column)
791792
return AggregateExpression(exp.Avg(this=col_expr))
792793

793794
@staticmethod
@@ -800,7 +801,7 @@ def max(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> Aggr
800801
Returns:
801802
MAX expression.
802803
"""
803-
col_expr = SQLFactory._extract_expression(column)
804+
col_expr = extract_expression(column)
804805
return AggregateExpression(exp.Max(this=col_expr))
805806

806807
@staticmethod
@@ -813,7 +814,7 @@ def min(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> Aggr
813814
Returns:
814815
MIN expression.
815816
"""
816-
col_expr = SQLFactory._extract_expression(column)
817+
col_expr = extract_expression(column)
817818
return AggregateExpression(exp.Min(this=col_expr))
818819

819820
@staticmethod
@@ -1034,45 +1035,6 @@ def to_literal(value: Any) -> FunctionExpression:
10341035
return FunctionExpression(value)
10351036
return FunctionExpression(exp.convert(value))
10361037

1037-
@staticmethod
1038-
def _to_expression(value: Any) -> exp.Expression:
1039-
"""Convert a Python value to a raw SQLGlot expression.
1040-
1041-
Args:
1042-
value: Python value or SQLGlot expression to convert.
1043-
1044-
Returns:
1045-
Raw SQLGlot expression.
1046-
"""
1047-
if isinstance(value, exp.Expression):
1048-
return value
1049-
return exp.convert(value)
1050-
1051-
@staticmethod
1052-
def _extract_expression(value: Any) -> exp.Expression:
1053-
"""Extract SQLGlot expression from value, handling our wrapper types.
1054-
1055-
Args:
1056-
value: String, SQLGlot expression, or our wrapper type.
1057-
1058-
Returns:
1059-
Raw SQLGlot expression.
1060-
"""
1061-
from sqlspec.builder._expression_wrappers import ExpressionWrapper
1062-
from sqlspec.builder.mixins._select_operations import Case
1063-
1064-
if isinstance(value, str):
1065-
return exp.column(value)
1066-
if isinstance(value, Column):
1067-
return value.sqlglot_expression
1068-
if isinstance(value, ExpressionWrapper):
1069-
return value.expression
1070-
if isinstance(value, Case):
1071-
return exp.Case(ifs=value.conditions, default=value.default)
1072-
if isinstance(value, exp.Expression):
1073-
return value
1074-
return exp.convert(value)
1075-
10761038
@staticmethod
10771039
def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> FunctionExpression:
10781040
"""Create a DECODE expression (Oracle-style conditional logic).
@@ -1109,14 +1071,14 @@ def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression,
11091071

11101072
for i in range(0, len(args) - 1, 2):
11111073
if i + 1 >= len(args):
1112-
default = SQLFactory._to_expression(args[i])
1074+
default = to_expression(args[i])
11131075
break
11141076

11151077
search_val = args[i]
11161078
result_val = args[i + 1]
11171079

1118-
search_expr = SQLFactory._to_expression(search_val)
1119-
result_expr = SQLFactory._to_expression(result_val)
1080+
search_expr = to_expression(search_val)
1081+
result_expr = to_expression(result_val)
11201082

11211083
condition = exp.EQ(this=col_expr, expression=search_expr)
11221084
conditions.append(exp.If(this=condition, true=result_expr))
@@ -1164,7 +1126,7 @@ def nvl(
11641126
COALESCE expression equivalent to NVL.
11651127
"""
11661128
col_expr = exp.column(column) if isinstance(column, str) else column
1167-
sub_expr = SQLFactory._to_expression(substitute_value)
1129+
sub_expr = to_expression(substitute_value)
11681130
return ConversionExpression(exp.Coalesce(expressions=[col_expr, sub_expr]))
11691131

11701132
@staticmethod
@@ -1192,8 +1154,8 @@ def nvl2(
11921154
```
11931155
"""
11941156
col_expr = exp.column(column) if isinstance(column, str) else column
1195-
not_null_expr = SQLFactory._to_expression(value_if_not_null)
1196-
null_expr = SQLFactory._to_expression(value_if_null)
1157+
not_null_expr = to_expression(value_if_not_null)
1158+
null_expr = to_expression(value_if_null)
11971159

11981160
is_null = exp.Is(this=col_expr, expression=exp.Null())
11991161
condition = exp.Not(this=is_null)

sqlspec/builder/_base.py

Lines changed: 82 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Provides abstract base classes and core functionality for SQL query builders.
44
"""
55

6+
import hashlib
7+
import uuid
68
from abc import ABC, abstractmethod
79
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union, cast
810

@@ -19,13 +21,15 @@
1921
from sqlspec.core.statement import SQL, StatementConfig
2022
from sqlspec.exceptions import SQLBuilderError
2123
from sqlspec.utils.logging import get_logger
22-
from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method
24+
from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method, is_expression
2325

2426
if TYPE_CHECKING:
2527
from sqlspec.core.result import SQLResult
2628

2729
__all__ = ("QueryBuilder", "SafeQuery")
2830

31+
MAX_PARAMETER_COLLISION_ATTEMPTS = 1000
32+
2933
logger = get_logger(__name__)
3034

3135

@@ -104,13 +108,9 @@ def set_expression(self, expression: exp.Expression) -> None:
104108
105109
Args:
106110
expression: SQLGlot expression to set
107-
108-
Raises:
109-
TypeError: If expression is not a SQLGlot Expression
110111
"""
111-
if not isinstance(expression, exp.Expression):
112-
msg = f"Expected Expression, got {type(expression)}"
113-
raise TypeError(msg)
112+
if not is_expression(expression):
113+
self._raise_invalid_expression_type(expression)
114114
self._expression = expression
115115

116116
def has_expression(self) -> bool:
@@ -151,6 +151,46 @@ def _raise_sql_builder_error(message: str, cause: Optional[BaseException] = None
151151
"""
152152
raise SQLBuilderError(message) from cause
153153

154+
@staticmethod
155+
def _raise_invalid_expression_type(expression: Any) -> NoReturn:
156+
"""Raise error for invalid expression type.
157+
158+
Args:
159+
expression: The invalid expression object
160+
161+
Raises:
162+
TypeError: Always raised for type mismatch
163+
"""
164+
msg = f"Expected Expression, got {type(expression)}"
165+
raise TypeError(msg)
166+
167+
@staticmethod
168+
def _raise_cte_query_error(alias: str, message: str) -> NoReturn:
169+
"""Raise error for CTE query issues.
170+
171+
Args:
172+
alias: CTE alias name
173+
message: Specific error message
174+
175+
Raises:
176+
SQLBuilderError: Always raised for CTE errors
177+
"""
178+
msg = f"CTE '{alias}': {message}"
179+
raise SQLBuilderError(msg)
180+
181+
@staticmethod
182+
def _raise_cte_parse_error(cause: BaseException) -> NoReturn:
183+
"""Raise error for CTE parsing failures.
184+
185+
Args:
186+
cause: The original parsing exception
187+
188+
Raises:
189+
SQLBuilderError: Always raised with chained cause
190+
"""
191+
msg = f"Failed to parse CTE query: {cause!s}"
192+
raise SQLBuilderError(msg) from cause
193+
154194
def _add_parameter(self, value: Any, context: Optional[str] = None) -> str:
155195
"""Adds a parameter to the query and returns its placeholder name.
156196
@@ -229,13 +269,11 @@ def _generate_unique_parameter_name(self, base_name: str) -> str:
229269
if base_name not in self._parameters:
230270
return base_name
231271

232-
for i in range(1, 1000):
272+
for i in range(1, MAX_PARAMETER_COLLISION_ATTEMPTS):
233273
name = f"{base_name}_{i}"
234274
if name not in self._parameters:
235275
return name
236276

237-
import uuid
238-
239277
return f"{base_name}_{uuid.uuid4().hex[:8]}"
240278

241279
def _merge_cte_parameters(self, cte_name: str, parameters: dict[str, Any]) -> dict[str, str]:
@@ -284,8 +322,6 @@ def _generate_builder_cache_key(self, config: "Optional[StatementConfig]" = None
284322
Returns:
285323
A unique cache key representing the builder state and configuration
286324
"""
287-
import hashlib
288-
289325
dialect_name: str = self.dialect_name or "default"
290326

291327
if self._expression is None:
@@ -339,35 +375,29 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
339375
if isinstance(query, QueryBuilder):
340376
query_expr = query.get_expression()
341377
if query_expr is None:
342-
self._raise_sql_builder_error("CTE query builder has no expression.")
378+
self._raise_cte_query_error(alias, "query builder has no expression")
343379
if not isinstance(query_expr, exp.Select):
344-
msg = f"CTE query builder expression must be a Select, got {type(query_expr).__name__}."
345-
self._raise_sql_builder_error(msg)
380+
self._raise_cte_query_error(alias, f"expression must be a Select, got {type(query_expr).__name__}")
346381
cte_select_expression = query_expr
347382
param_mapping = self._merge_cte_parameters(alias, query.parameters)
348-
updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping)
349-
if not isinstance(updated_expression, exp.Select):
350-
msg = f"Updated CTE expression must be a Select, got {type(updated_expression).__name__}."
351-
self._raise_sql_builder_error(msg)
352-
cte_select_expression = updated_expression
383+
cte_select_expression = cast(
384+
"exp.Select", self._update_placeholders_in_expression(cte_select_expression, param_mapping)
385+
)
353386

354387
elif isinstance(query, str):
355388
try:
356389
parsed_expression = sqlglot.parse_one(query, read=self.dialect_name)
357390
if not isinstance(parsed_expression, exp.Select):
358-
msg = f"CTE query string must parse to a SELECT statement, got {type(parsed_expression).__name__}."
359-
self._raise_sql_builder_error(msg)
391+
self._raise_cte_query_error(
392+
alias, f"query string must parse to SELECT, got {type(parsed_expression).__name__}"
393+
)
360394
cte_select_expression = parsed_expression
361395
except SQLGlotParseError as e:
362-
self._raise_sql_builder_error(f"Failed to parse CTE query string: {e!s}", e)
363-
except Exception as e:
364-
msg = f"An unexpected error occurred while parsing CTE query string: {e!s}"
365-
self._raise_sql_builder_error(msg, e)
396+
self._raise_cte_parse_error(e)
366397
elif isinstance(query, exp.Select):
367398
cte_select_expression = query
368399
else:
369-
msg = f"Invalid query type for CTE: {type(query).__name__}"
370-
self._raise_sql_builder_error(msg)
400+
self._raise_cte_query_error(alias, f"invalid query type: {type(query).__name__}")
371401

372402
self._with_ctes[alias] = exp.CTE(this=cte_select_expression, alias=exp.to_table(alias))
373403
return self
@@ -438,10 +468,9 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
438468
optimized = optimize(
439469
expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings
440470
)
441-
442471
cache.put("optimized", cache_key, optimized)
443-
444472
except Exception:
473+
logger.debug("Expression optimization failed, using original expression")
445474
return expression
446475
else:
447476
return optimized
@@ -482,18 +511,7 @@ def _to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
482511
"""
483512
safe_query = self.build()
484513

485-
if isinstance(safe_query.parameters, dict):
486-
kwargs = safe_query.parameters
487-
parameters: Optional[tuple[Any, ...]] = None
488-
else:
489-
kwargs = None
490-
parameters = (
491-
safe_query.parameters
492-
if isinstance(safe_query.parameters, tuple)
493-
else tuple(safe_query.parameters)
494-
if safe_query.parameters
495-
else None
496-
)
514+
kwargs, parameters = self._extract_statement_parameters(safe_query.parameters)
497515

498516
if config is None:
499517
config = StatementConfig(
@@ -521,6 +539,28 @@ def _to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
521539
return SQL(sql_string, *parameters, statement_config=config)
522540
return SQL(sql_string, statement_config=config)
523541

542+
def _extract_statement_parameters(
543+
self, raw_parameters: Any
544+
) -> "tuple[Optional[dict[str, Any]], Optional[tuple[Any, ...]]]":
545+
"""Extract parameters for SQL statement creation.
546+
547+
Args:
548+
raw_parameters: Raw parameter data from SafeQuery
549+
550+
Returns:
551+
Tuple of (kwargs, parameters) for SQL statement construction
552+
"""
553+
if isinstance(raw_parameters, dict):
554+
return raw_parameters, None
555+
556+
if isinstance(raw_parameters, tuple):
557+
return None, raw_parameters
558+
559+
if raw_parameters:
560+
return None, tuple(raw_parameters)
561+
562+
return None, None
563+
524564
def __str__(self) -> str:
525565
"""Return the SQL string representation of the query.
526566

0 commit comments

Comments
 (0)