|
3 | 3 | Provides abstract base classes and core functionality for SQL query builders. |
4 | 4 | """ |
5 | 5 |
|
| 6 | +import hashlib |
| 7 | +import uuid |
6 | 8 | from abc import ABC, abstractmethod |
7 | 9 | from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union, cast |
8 | 10 |
|
|
19 | 21 | from sqlspec.core.statement import SQL, StatementConfig |
20 | 22 | from sqlspec.exceptions import SQLBuilderError |
21 | 23 | from sqlspec.utils.logging import get_logger |
22 | | -from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method |
| 24 | +from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method, is_expression |
23 | 25 |
|
24 | 26 | if TYPE_CHECKING: |
25 | 27 | from sqlspec.core.result import SQLResult |
26 | 28 |
|
27 | 29 | __all__ = ("QueryBuilder", "SafeQuery") |
28 | 30 |
|
| 31 | +MAX_PARAMETER_COLLISION_ATTEMPTS = 1000 |
| 32 | + |
29 | 33 | logger = get_logger(__name__) |
30 | 34 |
|
31 | 35 |
|
@@ -104,13 +108,9 @@ def set_expression(self, expression: exp.Expression) -> None: |
104 | 108 |
|
105 | 109 | Args: |
106 | 110 | expression: SQLGlot expression to set |
107 | | -
|
108 | | - Raises: |
109 | | - TypeError: If expression is not a SQLGlot Expression |
110 | 111 | """ |
111 | | - if not isinstance(expression, exp.Expression): |
112 | | - msg = f"Expected Expression, got {type(expression)}" |
113 | | - raise TypeError(msg) |
| 112 | + if not is_expression(expression): |
| 113 | + self._raise_invalid_expression_type(expression) |
114 | 114 | self._expression = expression |
115 | 115 |
|
116 | 116 | def has_expression(self) -> bool: |
@@ -151,6 +151,46 @@ def _raise_sql_builder_error(message: str, cause: Optional[BaseException] = None |
151 | 151 | """ |
152 | 152 | raise SQLBuilderError(message) from cause |
153 | 153 |
|
| 154 | + @staticmethod |
| 155 | + def _raise_invalid_expression_type(expression: Any) -> NoReturn: |
| 156 | + """Raise error for invalid expression type. |
| 157 | +
|
| 158 | + Args: |
| 159 | + expression: The invalid expression object |
| 160 | +
|
| 161 | + Raises: |
| 162 | + TypeError: Always raised for type mismatch |
| 163 | + """ |
| 164 | + msg = f"Expected Expression, got {type(expression)}" |
| 165 | + raise TypeError(msg) |
| 166 | + |
| 167 | + @staticmethod |
| 168 | + def _raise_cte_query_error(alias: str, message: str) -> NoReturn: |
| 169 | + """Raise error for CTE query issues. |
| 170 | +
|
| 171 | + Args: |
| 172 | + alias: CTE alias name |
| 173 | + message: Specific error message |
| 174 | +
|
| 175 | + Raises: |
| 176 | + SQLBuilderError: Always raised for CTE errors |
| 177 | + """ |
| 178 | + msg = f"CTE '{alias}': {message}" |
| 179 | + raise SQLBuilderError(msg) |
| 180 | + |
| 181 | + @staticmethod |
| 182 | + def _raise_cte_parse_error(cause: BaseException) -> NoReturn: |
| 183 | + """Raise error for CTE parsing failures. |
| 184 | +
|
| 185 | + Args: |
| 186 | + cause: The original parsing exception |
| 187 | +
|
| 188 | + Raises: |
| 189 | + SQLBuilderError: Always raised with chained cause |
| 190 | + """ |
| 191 | + msg = f"Failed to parse CTE query: {cause!s}" |
| 192 | + raise SQLBuilderError(msg) from cause |
| 193 | + |
154 | 194 | def _add_parameter(self, value: Any, context: Optional[str] = None) -> str: |
155 | 195 | """Adds a parameter to the query and returns its placeholder name. |
156 | 196 |
|
@@ -229,13 +269,11 @@ def _generate_unique_parameter_name(self, base_name: str) -> str: |
229 | 269 | if base_name not in self._parameters: |
230 | 270 | return base_name |
231 | 271 |
|
232 | | - for i in range(1, 1000): |
| 272 | + for i in range(1, MAX_PARAMETER_COLLISION_ATTEMPTS): |
233 | 273 | name = f"{base_name}_{i}" |
234 | 274 | if name not in self._parameters: |
235 | 275 | return name |
236 | 276 |
|
237 | | - import uuid |
238 | | - |
239 | 277 | return f"{base_name}_{uuid.uuid4().hex[:8]}" |
240 | 278 |
|
241 | 279 | def _merge_cte_parameters(self, cte_name: str, parameters: dict[str, Any]) -> dict[str, str]: |
@@ -284,8 +322,6 @@ def _generate_builder_cache_key(self, config: "Optional[StatementConfig]" = None |
284 | 322 | Returns: |
285 | 323 | A unique cache key representing the builder state and configuration |
286 | 324 | """ |
287 | | - import hashlib |
288 | | - |
289 | 325 | dialect_name: str = self.dialect_name or "default" |
290 | 326 |
|
291 | 327 | if self._expression is None: |
@@ -339,35 +375,29 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str |
339 | 375 | if isinstance(query, QueryBuilder): |
340 | 376 | query_expr = query.get_expression() |
341 | 377 | if query_expr is None: |
342 | | - self._raise_sql_builder_error("CTE query builder has no expression.") |
| 378 | + self._raise_cte_query_error(alias, "query builder has no expression") |
343 | 379 | if not isinstance(query_expr, exp.Select): |
344 | | - msg = f"CTE query builder expression must be a Select, got {type(query_expr).__name__}." |
345 | | - self._raise_sql_builder_error(msg) |
| 380 | + self._raise_cte_query_error(alias, f"expression must be a Select, got {type(query_expr).__name__}") |
346 | 381 | cte_select_expression = query_expr |
347 | 382 | param_mapping = self._merge_cte_parameters(alias, query.parameters) |
348 | | - updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping) |
349 | | - if not isinstance(updated_expression, exp.Select): |
350 | | - msg = f"Updated CTE expression must be a Select, got {type(updated_expression).__name__}." |
351 | | - self._raise_sql_builder_error(msg) |
352 | | - cte_select_expression = updated_expression |
| 383 | + cte_select_expression = cast( |
| 384 | + "exp.Select", self._update_placeholders_in_expression(cte_select_expression, param_mapping) |
| 385 | + ) |
353 | 386 |
|
354 | 387 | elif isinstance(query, str): |
355 | 388 | try: |
356 | 389 | parsed_expression = sqlglot.parse_one(query, read=self.dialect_name) |
357 | 390 | if not isinstance(parsed_expression, exp.Select): |
358 | | - msg = f"CTE query string must parse to a SELECT statement, got {type(parsed_expression).__name__}." |
359 | | - self._raise_sql_builder_error(msg) |
| 391 | + self._raise_cte_query_error( |
| 392 | + alias, f"query string must parse to SELECT, got {type(parsed_expression).__name__}" |
| 393 | + ) |
360 | 394 | cte_select_expression = parsed_expression |
361 | 395 | except SQLGlotParseError as e: |
362 | | - self._raise_sql_builder_error(f"Failed to parse CTE query string: {e!s}", e) |
363 | | - except Exception as e: |
364 | | - msg = f"An unexpected error occurred while parsing CTE query string: {e!s}" |
365 | | - self._raise_sql_builder_error(msg, e) |
| 396 | + self._raise_cte_parse_error(e) |
366 | 397 | elif isinstance(query, exp.Select): |
367 | 398 | cte_select_expression = query |
368 | 399 | else: |
369 | | - msg = f"Invalid query type for CTE: {type(query).__name__}" |
370 | | - self._raise_sql_builder_error(msg) |
| 400 | + self._raise_cte_query_error(alias, f"invalid query type: {type(query).__name__}") |
371 | 401 |
|
372 | 402 | self._with_ctes[alias] = exp.CTE(this=cte_select_expression, alias=exp.to_table(alias)) |
373 | 403 | return self |
@@ -438,10 +468,9 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression: |
438 | 468 | optimized = optimize( |
439 | 469 | expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings |
440 | 470 | ) |
441 | | - |
442 | 471 | cache.put("optimized", cache_key, optimized) |
443 | | - |
444 | 472 | except Exception: |
| 473 | + logger.debug("Expression optimization failed, using original expression") |
445 | 474 | return expression |
446 | 475 | else: |
447 | 476 | return optimized |
@@ -482,18 +511,7 @@ def _to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL": |
482 | 511 | """ |
483 | 512 | safe_query = self.build() |
484 | 513 |
|
485 | | - if isinstance(safe_query.parameters, dict): |
486 | | - kwargs = safe_query.parameters |
487 | | - parameters: Optional[tuple[Any, ...]] = None |
488 | | - else: |
489 | | - kwargs = None |
490 | | - parameters = ( |
491 | | - safe_query.parameters |
492 | | - if isinstance(safe_query.parameters, tuple) |
493 | | - else tuple(safe_query.parameters) |
494 | | - if safe_query.parameters |
495 | | - else None |
496 | | - ) |
| 514 | + kwargs, parameters = self._extract_statement_parameters(safe_query.parameters) |
497 | 515 |
|
498 | 516 | if config is None: |
499 | 517 | config = StatementConfig( |
@@ -521,6 +539,28 @@ def _to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL": |
521 | 539 | return SQL(sql_string, *parameters, statement_config=config) |
522 | 540 | return SQL(sql_string, statement_config=config) |
523 | 541 |
|
| 542 | + def _extract_statement_parameters( |
| 543 | + self, raw_parameters: Any |
| 544 | + ) -> "tuple[Optional[dict[str, Any]], Optional[tuple[Any, ...]]]": |
| 545 | + """Extract parameters for SQL statement creation. |
| 546 | +
|
| 547 | + Args: |
| 548 | + raw_parameters: Raw parameter data from SafeQuery |
| 549 | +
|
| 550 | + Returns: |
| 551 | + Tuple of (kwargs, parameters) for SQL statement construction |
| 552 | + """ |
| 553 | + if isinstance(raw_parameters, dict): |
| 554 | + return raw_parameters, None |
| 555 | + |
| 556 | + if isinstance(raw_parameters, tuple): |
| 557 | + return None, raw_parameters |
| 558 | + |
| 559 | + if raw_parameters: |
| 560 | + return None, tuple(raw_parameters) |
| 561 | + |
| 562 | + return None, None |
| 563 | + |
524 | 564 | def __str__(self) -> str: |
525 | 565 | """Return the SQL string representation of the query. |
526 | 566 |
|
|
0 commit comments