Skip to content

Commit 4fdfe90

Browse files
committed
checkpoint
1 parent 98409e8 commit 4fdfe90

File tree

16 files changed

+1297
-562
lines changed

16 files changed

+1297
-562
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.5"
20+
rev: "v0.12.7"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

sqlspec/_sql.py

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def any(values: Union[list[Any], exp.Expression, str]) -> exp.Expression:
625625
```
626626
"""
627627
if isinstance(values, list):
628-
literals = [exp.Literal.string(str(v)) if isinstance(v, str) else exp.Literal.number(v) for v in values]
628+
literals = [SQLFactory._to_literal(v) for v in values]
629629
return exp.Any(this=exp.Array(expressions=literals))
630630
if isinstance(values, str):
631631
# Parse as SQL
@@ -738,6 +738,28 @@ def round(column: Union[str, exp.Expression], decimals: int = 0) -> exp.Expressi
738738
# Conversion Functions
739739
# ===================
740740

741+
@staticmethod
742+
def _to_literal(value: Any) -> exp.Expression:
743+
"""Convert a Python value to a SQLGlot literal expression.
744+
745+
Uses SQLGlot's built-in exp.convert() function for optimal dialect-agnostic
746+
literal creation. Handles all Python primitive types correctly:
747+
- None -> exp.Null (renders as NULL)
748+
- bool -> exp.Boolean (renders as TRUE/FALSE or 1/0 based on dialect)
749+
- int/float -> exp.Literal with is_number=True
750+
- str -> exp.Literal with is_string=True
751+
- exp.Expression -> returned as-is (passthrough)
752+
753+
Args:
754+
value: Python value or SQLGlot expression to convert.
755+
756+
Returns:
757+
SQLGlot expression representing the literal value.
758+
"""
759+
if isinstance(value, exp.Expression):
760+
return value
761+
return exp.convert(value)
762+
741763
@staticmethod
742764
def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> exp.Expression:
743765
"""Create a DECODE expression (Oracle-style conditional logic).
@@ -776,29 +798,14 @@ def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression,
776798
for i in range(0, len(args) - 1, 2):
777799
if i + 1 >= len(args):
778800
# Odd number of args means last one is default
779-
default = exp.Literal.string(str(args[i])) if not isinstance(args[i], exp.Expression) else args[i]
801+
default = SQLFactory._to_literal(args[i])
780802
break
781803

782804
search_val = args[i]
783805
result_val = args[i + 1]
784806

785-
if isinstance(search_val, str):
786-
search_expr = exp.Literal.string(search_val)
787-
elif isinstance(search_val, (int, float)):
788-
search_expr = exp.Literal.number(search_val)
789-
elif isinstance(search_val, exp.Expression):
790-
search_expr = search_val # type: ignore[assignment]
791-
else:
792-
search_expr = exp.Literal.string(str(search_val))
793-
794-
if isinstance(result_val, str):
795-
result_expr = exp.Literal.string(result_val)
796-
elif isinstance(result_val, (int, float)):
797-
result_expr = exp.Literal.number(result_val)
798-
elif isinstance(result_val, exp.Expression):
799-
result_expr = result_val # type: ignore[assignment]
800-
else:
801-
result_expr = exp.Literal.string(str(result_val))
807+
search_expr = SQLFactory._to_literal(search_val)
808+
result_expr = SQLFactory._to_literal(result_val)
802809

803810
condition = exp.EQ(this=col_expr, expression=search_expr)
804811
conditions.append(exp.When(this=condition, then=result_expr))
@@ -844,17 +851,44 @@ def nvl(column: Union[str, exp.Expression], substitute_value: Union[str, exp.Exp
844851
COALESCE expression equivalent to NVL.
845852
"""
846853
col_expr = exp.column(column) if isinstance(column, str) else column
854+
sub_expr = SQLFactory._to_literal(substitute_value)
855+
return exp.Coalesce(expressions=[col_expr, sub_expr])
847856

848-
if isinstance(substitute_value, str):
849-
sub_expr = exp.Literal.string(substitute_value)
850-
elif isinstance(substitute_value, (int, float)):
851-
sub_expr = exp.Literal.number(substitute_value)
852-
elif isinstance(substitute_value, exp.Expression):
853-
sub_expr = substitute_value # type: ignore[assignment]
854-
else:
855-
sub_expr = exp.Literal.string(str(substitute_value))
857+
@staticmethod
858+
def nvl2(
859+
column: Union[str, exp.Expression],
860+
value_if_not_null: Union[str, exp.Expression, Any],
861+
value_if_null: Union[str, exp.Expression, Any],
862+
) -> exp.Expression:
863+
"""Create an NVL2 (Oracle-style) expression using CASE.
856864
857-
return exp.Coalesce(expressions=[col_expr, sub_expr])
865+
NVL2 returns value_if_not_null if column is not NULL,
866+
otherwise returns value_if_null.
867+
868+
Args:
869+
column: Column to check for NULL.
870+
value_if_not_null: Value to use if column is NOT NULL.
871+
value_if_null: Value to use if column is NULL.
872+
873+
Returns:
874+
CASE expression equivalent to NVL2.
875+
876+
Example:
877+
```python
878+
# NVL2(salary, 'Has Salary', 'No Salary')
879+
sql.nvl2("salary", "Has Salary", "No Salary")
880+
```
881+
"""
882+
col_expr = exp.column(column) if isinstance(column, str) else column
883+
not_null_expr = SQLFactory._to_literal(value_if_not_null)
884+
null_expr = SQLFactory._to_literal(value_if_null)
885+
886+
# Create CASE WHEN column IS NOT NULL THEN value_if_not_null ELSE value_if_null END
887+
is_null = exp.Is(this=col_expr, expression=exp.Null())
888+
condition = exp.Not(this=is_null)
889+
when_clause = exp.If(this=condition, true=not_null_expr)
890+
891+
return exp.Case(ifs=[when_clause], default=null_expr)
858892

859893
# ===================
860894
# Bulk Operations
@@ -1057,7 +1091,7 @@ class Case:
10571091

10581092
def __init__(self) -> None:
10591093
"""Initialize the CASE expression builder."""
1060-
self._conditions: list[exp.When] = []
1094+
self._conditions: list[exp.If] = []
10611095
self._default: Optional[exp.Expression] = None
10621096

10631097
def when(self, condition: Union[str, exp.Expression], value: Union[str, exp.Expression, Any]) -> "Case":
@@ -1071,17 +1105,10 @@ def when(self, condition: Union[str, exp.Expression], value: Union[str, exp.Expr
10711105
Self for method chaining.
10721106
"""
10731107
cond_expr = exp.maybe_parse(condition) or exp.column(condition) if isinstance(condition, str) else condition
1108+
val_expr = SQLFactory._to_literal(value)
10741109

1075-
if isinstance(value, str):
1076-
val_expr = exp.Literal.string(value)
1077-
elif isinstance(value, (int, float)):
1078-
val_expr = exp.Literal.number(value)
1079-
elif isinstance(value, exp.Expression):
1080-
val_expr = value # type: ignore[assignment]
1081-
else:
1082-
val_expr = exp.Literal.string(str(value))
1083-
1084-
when_clause = exp.When(this=cond_expr, then=val_expr)
1110+
# SQLGlot uses exp.If for CASE WHEN clauses, not exp.When
1111+
when_clause = exp.If(this=cond_expr, true=val_expr)
10851112
self._conditions.append(when_clause)
10861113
return self
10871114

@@ -1094,14 +1121,7 @@ def else_(self, value: Union[str, exp.Expression, Any]) -> "Case":
10941121
Returns:
10951122
Self for method chaining.
10961123
"""
1097-
if isinstance(value, str):
1098-
self._default = exp.Literal.string(value)
1099-
elif isinstance(value, (int, float)):
1100-
self._default = exp.Literal.number(value)
1101-
elif isinstance(value, exp.Expression):
1102-
self._default = value
1103-
else:
1104-
self._default = exp.Literal.string(str(value))
1124+
self._default = SQLFactory._to_literal(value)
11051125
return self
11061126

11071127
def end(self) -> exp.Expression:

sqlspec/adapters/adbc/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing_extensions import NotRequired
88

99
from sqlspec.adapters.adbc._types import AdbcConnection
10-
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, create_adbc_statement_config
10+
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, get_adbc_statement_config
1111
from sqlspec.config import NoPoolSyncConfig
1212
from sqlspec.exceptions import ImproperConfigurationError
1313
from sqlspec.statement.sql import StatementConfig
@@ -104,7 +104,7 @@ def __init__(
104104
if statement_config is None:
105105
# Detect dialect and create appropriate config
106106
detected_dialect = str(self._get_dialect() or "sqlite")
107-
statement_config = create_adbc_statement_config(detected_dialect)
107+
statement_config = get_adbc_statement_config(detected_dialect)
108108

109109
super().__init__(
110110
connection_config=self.connection_config,
@@ -310,7 +310,7 @@ def session_manager() -> "Generator[AdbcDriver, None, None]":
310310
final_statement_config = (
311311
statement_config
312312
or self.statement_config
313-
or create_adbc_statement_config(str(self._get_dialect() or "sqlite"))
313+
or get_adbc_statement_config(str(self._get_dialect() or "sqlite"))
314314
)
315315
yield self.driver_type(connection=connection, statement_config=final_statement_config)
316316

0 commit comments

Comments
 (0)