Skip to content

Commit 4a743af

Browse files
authored
feat: refactor builder code to reduce duplication (#74)
Refactor the builder code to minimize duplication by introducing utility functions that check for specific attributes in SQL objects. This enhances code readability and maintainability.
1 parent 6eae415 commit 4a743af

File tree

9 files changed

+303
-280
lines changed

9 files changed

+303
-280
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.11"
20+
rev: "v0.12.12"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

sqlspec/builder/_base.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sqlspec.core.statement import SQL, StatementConfig
2020
from sqlspec.exceptions import SQLBuilderError
2121
from sqlspec.utils.logging import get_logger
22-
from sqlspec.utils.type_guards import has_sql_method, has_with_method
22+
from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method
2323

2424
if TYPE_CHECKING:
2525
from sqlspec.core.result import SQLResult
@@ -455,7 +455,7 @@ def _to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
455455

456456
if isinstance(safe_query.parameters, dict):
457457
kwargs = safe_query.parameters
458-
parameters: Optional[tuple] = None
458+
parameters: Optional[tuple[Any, ...]] = None
459459
else:
460460
kwargs = None
461461
parameters = (
@@ -479,7 +479,7 @@ def _to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
479479
config.dialect is not None
480480
and config.dialect != safe_query.dialect
481481
and self._expression is not None
482-
and hasattr(self._expression, "sql")
482+
and has_sql_method(self._expression)
483483
):
484484
try:
485485
sql_string = self._expression.sql(dialect=config.dialect, pretty=True)
@@ -498,26 +498,34 @@ def __str__(self) -> str:
498498
Returns:
499499
str: The SQL string for this query.
500500
"""
501-
try:
502-
return self.build().sql
503-
except Exception:
504-
return super().__str__()
501+
return self.build().sql
505502

506503
@property
507504
def dialect_name(self) -> "Optional[str]":
508505
"""Returns the name of the dialect, if set."""
509506
if isinstance(self.dialect, str):
510507
return self.dialect
511-
if self.dialect is not None:
512-
if isinstance(self.dialect, type) and issubclass(self.dialect, Dialect):
513-
return self.dialect.__name__.lower()
514-
if isinstance(self.dialect, Dialect):
515-
return type(self.dialect).__name__.lower()
516-
try:
517-
return self.dialect.__name__.lower()
518-
except AttributeError:
519-
pass
520-
return None
508+
if self.dialect is None:
509+
return None
510+
if isinstance(self.dialect, type) and issubclass(self.dialect, Dialect):
511+
return self.dialect.__name__.lower()
512+
if isinstance(self.dialect, Dialect):
513+
return type(self.dialect).__name__.lower()
514+
return getattr(self.dialect, "__name__", str(self.dialect)).lower()
515+
516+
def _merge_sql_object_parameters(self, sql_obj: Any) -> None:
517+
"""Merge parameters from a SQL object into the builder.
518+
519+
Args:
520+
sql_obj: Object with parameters attribute containing parameter mappings
521+
"""
522+
if not has_expression_and_parameters(sql_obj):
523+
return
524+
525+
sql_parameters = getattr(sql_obj, "parameters", {})
526+
for param_name, param_value in sql_parameters.items():
527+
unique_name = self._generate_unique_parameter_name(param_name)
528+
self.add_parameter(param_value, name=unique_name)
521529

522530
@property
523531
def parameters(self) -> dict[str, Any]:

sqlspec/builder/_ddl.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from sqlspec.builder._base import QueryBuilder, SafeQuery
1414
from sqlspec.core.result import SQLResult
15+
from sqlspec.utils.type_guards import has_sqlglot_expression, has_with_method
1516

1617
if TYPE_CHECKING:
1718
from sqlspec.builder._column import ColumnExpression
@@ -436,8 +437,8 @@ def check_constraint(self, condition: Union[str, "ColumnExpression"], name: "Opt
436437
self._raise_sql_builder_error("Check constraint must have a condition")
437438

438439
condition_str: str
439-
if hasattr(condition, "sqlglot_expression"):
440-
sqlglot_expr = getattr(condition, "sqlglot_expression", None)
440+
if has_sqlglot_expression(condition):
441+
sqlglot_expr = condition.sqlglot_expression
441442
condition_str = sqlglot_expr.sql(dialect=self.dialect) if sqlglot_expr else str(condition)
442443
else:
443444
condition_str = str(condition)
@@ -970,15 +971,15 @@ def _create_base_expression(self) -> exp.Expression:
970971

971972
if isinstance(self._select_query, SQL):
972973
select_expr = self._select_query.expression
973-
select_parameters = getattr(self._select_query, "parameters", None)
974+
select_parameters = self._select_query.parameters
974975
elif isinstance(self._select_query, Select):
975-
select_expr = getattr(self._select_query, "_expression", None)
976-
select_parameters = getattr(self._select_query, "_parameters", None)
976+
select_expr = self._select_query._expression
977+
select_parameters = self._select_query._parameters
977978

978-
with_ctes = getattr(self._select_query, "_with_ctes", {})
979+
with_ctes = self._select_query._with_ctes
979980
if with_ctes and select_expr and isinstance(select_expr, exp.Select):
980981
for alias, cte in with_ctes.items():
981-
if hasattr(select_expr, "with_"):
982+
if has_with_method(select_expr):
982983
select_expr = select_expr.with_(cte.this, as_=alias, copy=False)
983984
elif isinstance(self._select_query, str):
984985
select_expr = exp.maybe_parse(self._select_query)
@@ -1097,10 +1098,10 @@ def _create_base_expression(self) -> exp.Expression:
10971098

10981099
if isinstance(self._select_query, SQL):
10991100
select_expr = self._select_query.expression
1100-
select_parameters = getattr(self._select_query, "parameters", None)
1101+
select_parameters = self._select_query.parameters
11011102
elif isinstance(self._select_query, Select):
1102-
select_expr = getattr(self._select_query, "_expression", None)
1103-
select_parameters = getattr(self._select_query, "_parameters", None)
1103+
select_expr = self._select_query._expression
1104+
select_parameters = self._select_query._parameters
11041105
elif isinstance(self._select_query, str):
11051106
select_expr = exp.maybe_parse(self._select_query)
11061107
select_parameters = None
@@ -1195,10 +1196,10 @@ def _create_base_expression(self) -> exp.Expression:
11951196

11961197
if isinstance(self._select_query, SQL):
11971198
select_expr = self._select_query.expression
1198-
select_parameters = getattr(self._select_query, "parameters", None)
1199+
select_parameters = self._select_query.parameters
11991200
elif isinstance(self._select_query, Select):
1200-
select_expr = getattr(self._select_query, "_expression", None)
1201-
select_parameters = getattr(self._select_query, "_parameters", None)
1201+
select_expr = self._select_query._expression
1202+
select_parameters = self._select_query._parameters
12021203
elif isinstance(self._select_query, str):
12031204
select_expr = exp.maybe_parse(self._select_query)
12041205
select_parameters = None

sqlspec/builder/_insert.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlspec.builder.mixins import InsertFromSelectMixin, InsertIntoClauseMixin, InsertValuesMixin, ReturningClauseMixin
1414
from sqlspec.core.result import SQLResult
1515
from sqlspec.exceptions import SQLBuilderError
16+
from sqlspec.utils.type_guards import has_expression_and_sql
1617

1718
if TYPE_CHECKING:
1819
from collections.abc import Mapping, Sequence
@@ -124,12 +125,9 @@ def values(self, *values: Any, **kwargs: Any) -> "Self":
124125
return self.values_from_dict(kwargs)
125126

126127
if len(values) == 1:
127-
try:
128-
values_0 = values[0]
129-
if hasattr(values_0, "items"):
130-
return self.values_from_dict(values_0)
131-
except (AttributeError, TypeError):
132-
pass
128+
values_0 = values[0]
129+
if hasattr(values_0, "items") and hasattr(values_0, "keys"):
130+
return self.values_from_dict(values_0)
133131

134132
insert_expr = self._get_insert_expression()
135133

@@ -141,24 +139,18 @@ def values(self, *values: Any, **kwargs: Any) -> "Self":
141139
for i, value in enumerate(values):
142140
if isinstance(value, exp.Expression):
143141
value_placeholders.append(value)
144-
elif hasattr(value, "expression") and hasattr(value, "sql"):
142+
elif has_expression_and_sql(value):
145143
# Handle SQL objects (from sql.raw with parameters)
146144
expression = getattr(value, "expression", None)
147145
if expression is not None and isinstance(expression, exp.Expression):
148146
# Merge parameters from SQL object into builder
149-
if hasattr(value, "parameters"):
150-
sql_parameters = getattr(value, "parameters", {})
151-
for param_name, param_value in sql_parameters.items():
152-
self.add_parameter(param_value, name=param_name)
147+
self._merge_sql_object_parameters(value)
153148
value_placeholders.append(expression)
154149
else:
155150
# If expression is None, fall back to parsing the raw SQL
156151
sql_text = getattr(value, "sql", "")
157152
# Merge parameters even when parsing raw SQL
158-
if hasattr(value, "parameters"):
159-
sql_parameters = getattr(value, "parameters", {})
160-
for param_name, param_value in sql_parameters.items():
161-
self.add_parameter(param_value, name=param_name)
153+
self._merge_sql_object_parameters(value)
162154
# Check if sql_text is callable (like Expression.sql method)
163155
if callable(sql_text):
164156
sql_text = str(value)
@@ -376,7 +368,7 @@ def do_update(self, **kwargs: Any) -> "Insert":
376368
# Create SET expressions for the UPDATE
377369
set_expressions = []
378370
for col, val in kwargs.items():
379-
if hasattr(val, "expression") and hasattr(val, "sql"):
371+
if has_expression_and_sql(val):
380372
# Handle SQL objects (from sql.raw with parameters)
381373
expression = getattr(val, "expression", None)
382374
if expression is not None and isinstance(expression, exp.Expression):

sqlspec/builder/_parsing_utils.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from sqlglot import exp, maybe_parse, parse_one
1111

1212
from sqlspec.core.parameters import ParameterStyle
13-
from sqlspec.utils.type_guards import has_expression_attr, has_parameter_builder
13+
from sqlspec.utils.type_guards import (
14+
has_expression_and_parameters,
15+
has_expression_and_sql,
16+
has_expression_attr,
17+
has_parameter_builder,
18+
)
1419

1520

1621
def parse_column_expression(
@@ -38,32 +43,29 @@ def parse_column_expression(
3843
return column_input
3944

4045
# Handle SQL objects (from sql.raw with parameters)
41-
if hasattr(column_input, "expression") and hasattr(column_input, "sql"):
46+
if has_expression_and_sql(column_input):
4247
# This is likely a SQL object
4348
expression = getattr(column_input, "expression", None)
4449
if expression is not None and isinstance(expression, exp.Expression):
4550
# Merge parameters from SQL object into builder if available
46-
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
51+
if builder and has_expression_and_parameters(column_input) and hasattr(builder, "add_parameter"):
4752
sql_parameters = getattr(column_input, "parameters", {})
4853
for param_name, param_value in sql_parameters.items():
4954
builder.add_parameter(param_value, name=param_name)
5055
return cast("exp.Expression", expression)
5156
# If expression is None, fall back to parsing the raw SQL
5257
sql_text = getattr(column_input, "sql", "")
5358
# Merge parameters even when parsing raw SQL
54-
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
59+
if builder and has_expression_and_parameters(column_input) and hasattr(builder, "add_parameter"):
5560
sql_parameters = getattr(column_input, "parameters", {})
5661
for param_name, param_value in sql_parameters.items():
5762
builder.add_parameter(param_value, name=param_name)
5863
return exp.maybe_parse(sql_text) or exp.column(str(sql_text))
5964

6065
if has_expression_attr(column_input):
61-
try:
62-
attr_value = column_input._expression
63-
if isinstance(attr_value, exp.Expression):
64-
return attr_value
65-
except AttributeError:
66-
pass
66+
attr_value = getattr(column_input, "_expression", None)
67+
if isinstance(attr_value, exp.Expression):
68+
return attr_value
6769

6870
return exp.maybe_parse(column_input) or exp.column(str(column_input))
6971

@@ -178,14 +180,10 @@ def parse_condition_expression(
178180
)
179181
condition_input = converted_condition
180182

181-
try:
182-
return exp.condition(condition_input)
183-
except Exception:
184-
try:
185-
parsed = exp.maybe_parse(condition_input) # type: ignore[var-annotated]
186-
return parsed or exp.condition(condition_input)
187-
except Exception:
188-
return exp.condition(condition_input)
183+
parsed: Optional[exp.Expression] = exp.maybe_parse(condition_input)
184+
if parsed:
185+
return parsed
186+
return exp.condition(condition_input)
189187

190188

191189
__all__ = ("parse_column_expression", "parse_condition_expression", "parse_order_expression", "parse_table_expression")

sqlspec/builder/_update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def build(self) -> "SafeQuery":
176176
msg = "No UPDATE expression to build or expression is of the wrong type."
177177
raise SQLBuilderError(msg)
178178

179-
if getattr(self._expression, "this", None) is None:
179+
if self._expression.this is None:
180180
msg = "No table specified for UPDATE statement."
181181
raise SQLBuilderError(msg)
182182

0 commit comments

Comments
 (0)