Skip to content

Commit f597baf

Browse files
authored
feat: enhance sync_tools and migration infrastructure (#81)
Core infrastructure updates for async/sync interoperability: - Enhanced sync_tools with thread-local state tracking - Improved context manager wrapper for thread consistency - Better error handling in await_ function for running loops - Fixed type annotations across core modules - Added MigrationContext for runtime migration information - Enhanced migration runners with extension support These changes provide the foundation for data dictionary and litestar session features.
1 parent 85c55b5 commit f597baf

File tree

14 files changed

+754
-135
lines changed

14 files changed

+754
-135
lines changed

sqlspec/builder/_column.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from collections.abc import Iterable
8+
from datetime import date, datetime
89
from typing import Any, Optional, cast
910

1011
from sqlglot import exp
@@ -67,33 +68,53 @@ def __init__(self, name: str, table: Optional[str] = None) -> None:
6768
else:
6869
self._expression = exp.Column(this=exp.Identifier(this=name))
6970

71+
def _convert_value(self, value: Any) -> exp.Expression:
72+
"""Convert a Python value to a SQLGlot expression.
73+
74+
Special handling for datetime objects to prevent SQLGlot from
75+
converting them to TIME_STR_TO_TIME function calls. Datetime
76+
objects should be passed as parameters, not converted to SQL functions.
77+
78+
Args:
79+
value: The value to convert
80+
81+
Returns:
82+
A SQLGlot expression representing the value
83+
"""
84+
if isinstance(value, (datetime, date)):
85+
# Create a Literal with the datetime value directly
86+
# This will be parameterized by the QueryBuilder's _parameterize_expression
87+
# Don't use exp.convert() which would create TIME_STR_TO_TIME
88+
return exp.Literal(this=value, is_string=False)
89+
return exp.convert(value)
90+
7091
def __eq__(self, other: object) -> ColumnExpression: # type: ignore[override]
7192
"""Equal to (==)."""
7293
if other is None:
7394
return ColumnExpression(exp.Is(this=self._expression, expression=exp.Null()))
74-
return ColumnExpression(exp.EQ(this=self._expression, expression=exp.convert(other)))
95+
return ColumnExpression(exp.EQ(this=self._expression, expression=self._convert_value(other)))
7596

7697
def __ne__(self, other: object) -> ColumnExpression: # type: ignore[override]
7798
"""Not equal to (!=)."""
7899
if other is None:
79100
return ColumnExpression(exp.Not(this=exp.Is(this=self._expression, expression=exp.Null())))
80-
return ColumnExpression(exp.NEQ(this=self._expression, expression=exp.convert(other)))
101+
return ColumnExpression(exp.NEQ(this=self._expression, expression=self._convert_value(other)))
81102

82103
def __gt__(self, other: Any) -> ColumnExpression:
83104
"""Greater than (>)."""
84-
return ColumnExpression(exp.GT(this=self._expression, expression=exp.convert(other)))
105+
return ColumnExpression(exp.GT(this=self._expression, expression=self._convert_value(other)))
85106

86107
def __ge__(self, other: Any) -> ColumnExpression:
87108
"""Greater than or equal (>=)."""
88-
return ColumnExpression(exp.GTE(this=self._expression, expression=exp.convert(other)))
109+
return ColumnExpression(exp.GTE(this=self._expression, expression=self._convert_value(other)))
89110

90111
def __lt__(self, other: Any) -> ColumnExpression:
91112
"""Less than (<)."""
92-
return ColumnExpression(exp.LT(this=self._expression, expression=exp.convert(other)))
113+
return ColumnExpression(exp.LT(this=self._expression, expression=self._convert_value(other)))
93114

94115
def __le__(self, other: Any) -> ColumnExpression:
95116
"""Less than or equal (<=)."""
96-
return ColumnExpression(exp.LTE(this=self._expression, expression=exp.convert(other)))
117+
return ColumnExpression(exp.LTE(this=self._expression, expression=self._convert_value(other)))
97118

98119
def __invert__(self) -> ColumnExpression:
99120
"""Apply NOT operator (~)."""
@@ -102,18 +123,20 @@ def __invert__(self) -> ColumnExpression:
102123
def like(self, pattern: str, escape: Optional[str] = None) -> ColumnExpression:
103124
"""SQL LIKE pattern matching."""
104125
if escape:
105-
like_expr = exp.Like(this=self._expression, expression=exp.convert(pattern), escape=exp.convert(escape))
126+
like_expr = exp.Like(
127+
this=self._expression, expression=self._convert_value(pattern), escape=self._convert_value(escape)
128+
)
106129
else:
107-
like_expr = exp.Like(this=self._expression, expression=exp.convert(pattern))
130+
like_expr = exp.Like(this=self._expression, expression=self._convert_value(pattern))
108131
return ColumnExpression(like_expr)
109132

110133
def ilike(self, pattern: str) -> ColumnExpression:
111134
"""Case-insensitive LIKE."""
112-
return ColumnExpression(exp.ILike(this=self._expression, expression=exp.convert(pattern)))
135+
return ColumnExpression(exp.ILike(this=self._expression, expression=self._convert_value(pattern)))
113136

114137
def in_(self, values: Iterable[Any]) -> ColumnExpression:
115138
"""SQL IN clause."""
116-
converted_values = [exp.convert(v) for v in values]
139+
converted_values = [self._convert_value(v) for v in values]
117140
return ColumnExpression(exp.In(this=self._expression, expressions=converted_values))
118141

119142
def not_in(self, values: Iterable[Any]) -> ColumnExpression:
@@ -122,7 +145,9 @@ def not_in(self, values: Iterable[Any]) -> ColumnExpression:
122145

123146
def between(self, start: Any, end: Any) -> ColumnExpression:
124147
"""SQL BETWEEN clause."""
125-
return ColumnExpression(exp.Between(this=self._expression, low=exp.convert(start), high=exp.convert(end)))
148+
return ColumnExpression(
149+
exp.Between(this=self._expression, low=self._convert_value(start), high=self._convert_value(end))
150+
)
126151

127152
def is_null(self) -> ColumnExpression:
128153
"""SQL IS NULL."""
@@ -142,12 +167,12 @@ def not_ilike(self, pattern: str) -> ColumnExpression:
142167

143168
def any_(self, values: Iterable[Any]) -> ColumnExpression:
144169
"""SQL = ANY(...) clause."""
145-
converted_values = [exp.convert(v) for v in values]
170+
converted_values = [self._convert_value(v) for v in values]
146171
return ColumnExpression(exp.EQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
147172

148173
def not_any_(self, values: Iterable[Any]) -> ColumnExpression:
149174
"""SQL <> ANY(...) clause."""
150-
converted_values = [exp.convert(v) for v in values]
175+
converted_values = [self._convert_value(v) for v in values]
151176
return ColumnExpression(exp.NEQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
152177

153178
def lower(self) -> "FunctionColumn":
@@ -186,14 +211,14 @@ def ceil(self) -> "FunctionColumn":
186211

187212
def substring(self, start: int, length: Optional[int] = None) -> "FunctionColumn":
188213
"""SQL SUBSTRING() function."""
189-
args = [exp.convert(start)]
214+
args = [self._convert_value(start)]
190215
if length is not None:
191-
args.append(exp.convert(length))
216+
args.append(self._convert_value(length))
192217
return FunctionColumn(exp.Substring(this=self._expression, expressions=args))
193218

194219
def coalesce(self, *values: Any) -> "FunctionColumn":
195220
"""SQL COALESCE() function."""
196-
expressions = [self._expression] + [exp.convert(v) for v in values]
221+
expressions = [self._expression] + [self._convert_value(v) for v in values]
197222
return FunctionColumn(exp.Coalesce(expressions=expressions))
198223

199224
def cast(self, data_type: str) -> "FunctionColumn":
@@ -272,22 +297,42 @@ class FunctionColumn:
272297
def __init__(self, expression: exp.Expression) -> None:
273298
self._expression = expression
274299

300+
def _convert_value(self, value: Any) -> exp.Expression:
301+
"""Convert a Python value to a SQLGlot expression.
302+
303+
Special handling for datetime objects to prevent SQLGlot from
304+
converting them to TIME_STR_TO_TIME function calls. Datetime
305+
objects should be passed as parameters, not converted to SQL functions.
306+
307+
Args:
308+
value: The value to convert
309+
310+
Returns:
311+
A SQLGlot expression representing the value
312+
"""
313+
if isinstance(value, (datetime, date)):
314+
# Create a Literal with the datetime value directly
315+
# This will be parameterized by the QueryBuilder's _parameterize_expression
316+
# Don't use exp.convert() which would create TIME_STR_TO_TIME
317+
return exp.Literal(this=value, is_string=False)
318+
return exp.convert(value)
319+
275320
def __eq__(self, other: object) -> ColumnExpression: # type: ignore[override]
276-
return ColumnExpression(exp.EQ(this=self._expression, expression=exp.convert(other)))
321+
return ColumnExpression(exp.EQ(this=self._expression, expression=self._convert_value(other)))
277322

278323
def __ne__(self, other: object) -> ColumnExpression: # type: ignore[override]
279-
return ColumnExpression(exp.NEQ(this=self._expression, expression=exp.convert(other)))
324+
return ColumnExpression(exp.NEQ(this=self._expression, expression=self._convert_value(other)))
280325

281326
def like(self, pattern: str) -> ColumnExpression:
282-
return ColumnExpression(exp.Like(this=self._expression, expression=exp.convert(pattern)))
327+
return ColumnExpression(exp.Like(this=self._expression, expression=self._convert_value(pattern)))
283328

284329
def ilike(self, pattern: str) -> ColumnExpression:
285330
"""Case-insensitive LIKE."""
286-
return ColumnExpression(exp.ILike(this=self._expression, expression=exp.convert(pattern)))
331+
return ColumnExpression(exp.ILike(this=self._expression, expression=self._convert_value(pattern)))
287332

288333
def in_(self, values: Iterable[Any]) -> ColumnExpression:
289334
"""SQL IN clause."""
290-
converted_values = [exp.convert(v) for v in values]
335+
converted_values = [self._convert_value(v) for v in values]
291336
return ColumnExpression(exp.In(this=self._expression, expressions=converted_values))
292337

293338
def not_in_(self, values: Iterable[Any]) -> ColumnExpression:
@@ -304,7 +349,9 @@ def not_ilike(self, pattern: str) -> ColumnExpression:
304349

305350
def between(self, start: Any, end: Any) -> ColumnExpression:
306351
"""SQL BETWEEN clause."""
307-
return ColumnExpression(exp.Between(this=self._expression, low=exp.convert(start), high=exp.convert(end)))
352+
return ColumnExpression(
353+
exp.Between(this=self._expression, low=self._convert_value(start), high=self._convert_value(end))
354+
)
308355

309356
def is_null(self) -> ColumnExpression:
310357
"""SQL IS NULL."""
@@ -316,12 +363,12 @@ def is_not_null(self) -> ColumnExpression:
316363

317364
def any_(self, values: Iterable[Any]) -> ColumnExpression:
318365
"""SQL = ANY(...) clause."""
319-
converted_values = [exp.convert(v) for v in values]
366+
converted_values = [self._convert_value(v) for v in values]
320367
return ColumnExpression(exp.EQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
321368

322369
def not_any_(self, values: Iterable[Any]) -> ColumnExpression:
323370
"""SQL <> ANY(...) clause."""
324-
converted_values = [exp.convert(v) for v in values]
371+
converted_values = [self._convert_value(v) for v in values]
325372
return ColumnExpression(exp.NEQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
326373

327374
def alias(self, alias_name: str) -> exp.Expression:

sqlspec/builder/_insert.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,22 +290,63 @@ def on_conflict_do_nothing(self, *columns: str) -> "Insert":
290290
return self.on_conflict(*columns).do_nothing()
291291

292292
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
293-
"""Adds conflict resolution using the ON CONFLICT syntax (cross-database compatible).
293+
"""Adds MySQL-style ON DUPLICATE KEY UPDATE clause.
294294
295295
Args:
296-
**kwargs: Column-value pairs to update on conflict.
296+
**kwargs: Column-value pairs to update on duplicate key.
297297
298298
Returns:
299299
The current builder instance for method chaining.
300300
301301
Note:
302-
This method uses PostgreSQL-style ON CONFLICT syntax but SQLGlot will
303-
transpile it to the appropriate syntax for each database (MySQL's
304-
ON DUPLICATE KEY UPDATE, etc.).
302+
This method creates MySQL-specific ON DUPLICATE KEY UPDATE syntax.
303+
For PostgreSQL, use on_conflict() instead.
305304
"""
306305
if not kwargs:
307306
return self
308-
return self.on_conflict().do_update(**kwargs)
307+
308+
insert_expr = self._get_insert_expression()
309+
310+
# Create SET expressions for MySQL ON DUPLICATE KEY UPDATE
311+
set_expressions = []
312+
for col, val in kwargs.items():
313+
if has_expression_and_sql(val):
314+
# Handle SQL objects (from sql.raw with parameters)
315+
expression = getattr(val, "expression", None)
316+
if expression is not None and isinstance(expression, exp.Expression):
317+
# Merge parameters from SQL object into builder
318+
self._merge_sql_object_parameters(val)
319+
value_expr = expression
320+
else:
321+
# If expression is None, fall back to parsing the raw SQL
322+
sql_text = getattr(val, "sql", "")
323+
# Merge parameters even when parsing raw SQL
324+
self._merge_sql_object_parameters(val)
325+
# Check if sql_text is callable (like Expression.sql method)
326+
if callable(sql_text):
327+
sql_text = str(val)
328+
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
329+
elif isinstance(val, exp.Expression):
330+
value_expr = val
331+
else:
332+
# Create parameter for regular values
333+
param_name = self._generate_unique_parameter_name(col)
334+
_, param_name = self.add_parameter(val, name=param_name)
335+
value_expr = exp.Placeholder(this=param_name)
336+
337+
set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
338+
339+
# For MySQL, create ON CONFLICT with duplicate=True flag
340+
# This tells SQLGlot to generate ON DUPLICATE KEY UPDATE
341+
on_conflict = exp.OnConflict(
342+
duplicate=True, # This flag makes it MySQL-specific
343+
action=exp.var("UPDATE"), # MySQL requires UPDATE action
344+
expressions=set_expressions or None,
345+
)
346+
347+
insert_expr.set("conflict", on_conflict)
348+
349+
return self
309350

310351

311352
class ConflictBuilder:

0 commit comments

Comments
 (0)