Skip to content

Commit c869b92

Browse files
authored
feat: core handling and cleanup (#223)
- Updated the BigQuery driver to raise an error for positional parameters instead of logging a warning.
- Removed the `_is_modifying_operation` method from the DuckDB driver and replaced its usage with a direct call to `is_modifying_operation` on the statement.
- Enhanced the caching mechanism in the core module to include pipeline metrics and reset functionality.
- Introduced the `OperationProfile` and `ParameterProfile` classes to encapsulate metadata about SQL operations and parameters.
- Updated SQL processing to use the shared pipeline for compilation, improving performance and consistency.
1 parent 4a0ba88 commit c869b92

File tree

15 files changed

+854
-332
lines changed

15 files changed

+854
-332
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,5 @@ requirements/*
6666
!requirements/example-feature
6767
!requirements/README.md
6868
!.claude/bootstrap.md
69+
.pre-commit-cache
70+
.gh-cache

sqlspec/adapters/adbc/driver.py

Lines changed: 12 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
1515
from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter
1616
from sqlspec.core.cache import get_cache_config
17-
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
17+
from sqlspec.core.parameters import (
18+
ParameterProfile,
19+
ParameterStyle,
20+
ParameterStyleConfig,
21+
ParameterValidator,
22+
validate_parameter_alignment,
23+
)
1824
from sqlspec.core.result import create_arrow_result
1925
from sqlspec.core.statement import SQL, StatementConfig
2026
from sqlspec.driver import SyncDriverAdapterBase
@@ -69,87 +75,14 @@
6975
"snowflake": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]),
7076
}
7177

72-
73-
def _count_placeholders(expression: Any) -> int:
74-
"""Count the number of unique parameter placeholders in a SQLGlot expression.
75-
76-
For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2)
77-
For QMARK (?) style: counts total occurrences (each ? is a separate parameter)
78-
For named (:name) style: counts unique parameter names
79-
80-
Args:
81-
expression: SQLGlot AST expression
82-
83-
Returns:
84-
Number of unique parameter placeholders expected
85-
"""
86-
numeric_params = set() # For $1, $2 style
87-
qmark_count = 0 # For ? style
88-
named_params = set() # For :name style
89-
90-
def count_node(node: Any) -> Any:
91-
nonlocal qmark_count
92-
if isinstance(node, exp.Parameter):
93-
# PostgreSQL style: $1, $2, etc.
94-
param_str = str(node)
95-
if param_str.startswith("$") and param_str[1:].isdigit():
96-
numeric_params.add(int(param_str[1:]))
97-
elif ":" in param_str:
98-
# Named parameter: :name
99-
named_params.add(param_str)
100-
else:
101-
# Other parameter formats
102-
named_params.add(param_str)
103-
elif isinstance(node, exp.Placeholder):
104-
# QMARK style: ?
105-
qmark_count += 1
106-
return node
107-
108-
expression.transform(count_node)
109-
110-
# Return the appropriate count based on parameter style detected
111-
if numeric_params:
112-
# PostgreSQL style: return highest numbered parameter
113-
return max(numeric_params)
114-
if named_params:
115-
# Named parameters: return count of unique names
116-
return len(named_params)
117-
# QMARK style: return total count
118-
return qmark_count
78+
_AST_PARAMETER_VALIDATOR: "ParameterValidator" = ParameterValidator()
11979

12080

12181
def _is_execute_many_parameters(parameters: Any) -> bool:
12282
"""Check if parameters are in execute_many format (list/tuple of lists/tuples)."""
12383
return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple))
12484

12585

126-
def _validate_parameter_counts(expression: Any, parameters: Any, dialect: str) -> None:
127-
"""Validate parameter count against placeholder count in SQL."""
128-
placeholder_count = _count_placeholders(expression)
129-
is_execute_many = _is_execute_many_parameters(parameters)
130-
131-
if is_execute_many:
132-
# For execute_many, validate each inner parameter set
133-
for i, param_set in enumerate(parameters):
134-
param_count = len(param_set) if isinstance(param_set, (list, tuple)) else 0
135-
if param_count != placeholder_count:
136-
msg = f"Parameter count mismatch in set {i}: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
137-
raise SQLSpecError(msg)
138-
else:
139-
# For single execution, validate the parameter set directly
140-
param_count = (
141-
len(parameters)
142-
if isinstance(parameters, (list, tuple))
143-
else len(parameters)
144-
if isinstance(parameters, dict)
145-
else 0
146-
)
147-
148-
if param_count != placeholder_count:
149-
msg = f"Parameter count mismatch: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
150-
raise SQLSpecError(msg)
151-
152-
15386
def _find_null_positions(parameters: Any) -> set[int]:
15487
"""Find positions of None values in parameters for single execution."""
15588
null_positions = set()
@@ -187,14 +120,15 @@ def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "post
187120
if not parameters:
188121
return expression, parameters
189122

190-
# Validate parameter count before transformation
191-
_validate_parameter_counts(expression, parameters, dialect)
192-
193123
# For execute_many operations, skip AST transformation as different parameter
194124
# sets may have None values in different positions, making transformation complex
195125
if _is_execute_many_parameters(parameters):
196126
return expression, parameters
197127

128+
parameter_info = _AST_PARAMETER_VALIDATOR.extract_parameters(expression.sql(dialect=dialect))
129+
parameter_profile = ParameterProfile(parameter_info)
130+
validate_parameter_alignment(parameter_profile, parameters)
131+
198132
# Find positions of None values for single execution
199133
null_positions = _find_null_positions(parameters)
200134
if not null_positions:

sqlspec/adapters/asyncmy/config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def __init__(
117117
if "port" not in processed_pool_config:
118118
processed_pool_config["port"] = 3306
119119

120-
if statement_config is None:
120+
using_default_statement_config = statement_config is None
121+
if using_default_statement_config:
121122
statement_config = asyncmy_statement_config
122123

123124
processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
@@ -127,6 +128,14 @@ def __init__(
127128
if "json_deserializer" not in processed_driver_features:
128129
processed_driver_features["json_deserializer"] = from_json
129130

131+
if statement_config is None:
132+
statement_config = asyncmy_statement_config
133+
134+
json_serializer = processed_driver_features.get("json_serializer")
135+
if json_serializer is not None and using_default_statement_config:
136+
parameter_config = statement_config.parameter_config.with_json_serializers(json_serializer)
137+
statement_config = statement_config.replace(parameter_config=parameter_config)
138+
130139
super().__init__(
131140
pool_config=processed_pool_config,
132141
pool_instance=pool_instance,

sqlspec/adapters/asyncmy/driver.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from sqlspec.utils.serializers import to_json
3131

3232
if TYPE_CHECKING:
33-
from collections.abc import Callable
3433
from contextlib import AbstractAsyncContextManager
3534

3635
from sqlspec.adapters.asyncmy._types import AsyncmyConnection
@@ -243,82 +242,11 @@ def __init__(
243242
dialect="mysql",
244243
)
245244

246-
final_statement_config = self._apply_json_serializer_feature(final_statement_config, driver_features)
247-
248245
super().__init__(
249246
connection=connection, statement_config=final_statement_config, driver_features=driver_features
250247
)
251248
self._data_dictionary: AsyncDataDictionaryBase | None = None
252249

253-
@staticmethod
254-
def _clone_parameter_config(
255-
parameter_config: ParameterStyleConfig, type_coercion_map: "dict[type[Any], Callable[[Any], Any]]"
256-
) -> ParameterStyleConfig:
257-
"""Create a copy of the parameter configuration with updated coercion map.
258-
259-
Args:
260-
parameter_config: Existing parameter configuration to copy.
261-
type_coercion_map: Updated coercion mapping for parameter serialization.
262-
263-
Returns:
264-
ParameterStyleConfig with the updated type coercion map applied.
265-
"""
266-
267-
supported_execution_styles = (
268-
set(parameter_config.supported_execution_parameter_styles)
269-
if parameter_config.supported_execution_parameter_styles is not None
270-
else None
271-
)
272-
273-
return ParameterStyleConfig(
274-
default_parameter_style=parameter_config.default_parameter_style,
275-
supported_parameter_styles=set(parameter_config.supported_parameter_styles),
276-
supported_execution_parameter_styles=supported_execution_styles,
277-
default_execution_parameter_style=parameter_config.default_execution_parameter_style,
278-
type_coercion_map=type_coercion_map,
279-
has_native_list_expansion=parameter_config.has_native_list_expansion,
280-
needs_static_script_compilation=parameter_config.needs_static_script_compilation,
281-
allow_mixed_parameter_styles=parameter_config.allow_mixed_parameter_styles,
282-
preserve_parameter_format=parameter_config.preserve_parameter_format,
283-
preserve_original_params_for_many=parameter_config.preserve_original_params_for_many,
284-
output_transformer=parameter_config.output_transformer,
285-
ast_transformer=parameter_config.ast_transformer,
286-
)
287-
288-
@staticmethod
289-
def _apply_json_serializer_feature(
290-
statement_config: "StatementConfig", driver_features: "dict[str, Any] | None"
291-
) -> "StatementConfig":
292-
"""Apply driver-level JSON serializer customization to the statement config.
293-
294-
Args:
295-
statement_config: Base statement configuration for the driver.
296-
driver_features: Driver feature mapping provided via configuration.
297-
298-
Returns:
299-
StatementConfig with serializer adjustments applied when configured.
300-
"""
301-
302-
if not driver_features:
303-
return statement_config
304-
305-
serializer = driver_features.get("json_serializer")
306-
if serializer is None:
307-
return statement_config
308-
309-
parameter_config = statement_config.parameter_config
310-
type_coercion_map = dict(parameter_config.type_coercion_map)
311-
312-
def serialize_tuple(value: Any) -> Any:
313-
return serializer(list(value))
314-
315-
type_coercion_map[dict] = serializer
316-
type_coercion_map[list] = serializer
317-
type_coercion_map[tuple] = serialize_tuple
318-
319-
updated_parameter_config = AsyncmyDriver._clone_parameter_config(parameter_config, type_coercion_map)
320-
return statement_config.replace(parameter_config=updated_parameter_config)
321-
322250
def with_cursor(self, connection: "AsyncmyConnection") -> "AsyncmyCursor":
323251
"""Create cursor context manager for the connection.
324252

sqlspec/adapters/bigquery/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ def _create_bq_parameters(
279279
raise SQLSpecError(msg)
280280

281281
elif isinstance(parameters, (list, tuple)):
282-
logger.warning("BigQuery received positional parameters instead of named parameters")
283-
return []
282+
msg = "BigQuery driver requires named parameters (e.g., @name); positional parameters are not supported"
283+
raise SQLSpecError(msg)
284284

285285
return bq_parameters
286286

sqlspec/adapters/duckdb/driver.py

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import TYPE_CHECKING, Any, Final
66

77
import duckdb
8-
from sqlglot import exp
98

109
from sqlspec.adapters.duckdb.data_dictionary import DuckDBSyncDataDictionary
1110
from sqlspec.adapters.duckdb.type_converter import DuckDBTypeConverter
@@ -225,36 +224,20 @@ def __init__(
225224
statement_config = updated_config
226225

227226
if driver_features:
227+
param_config = statement_config.parameter_config
228228
json_serializer = driver_features.get("json_serializer")
229-
enable_uuid_conversion = driver_features.get("enable_uuid_conversion", True)
229+
if json_serializer:
230+
param_config = param_config.with_json_serializers(json_serializer, tuple_strategy="tuple")
230231

231-
if json_serializer or not enable_uuid_conversion:
232+
enable_uuid_conversion = driver_features.get("enable_uuid_conversion", True)
233+
if not enable_uuid_conversion:
232234
type_converter = DuckDBTypeConverter(enable_uuid_conversion=enable_uuid_conversion)
233-
type_coercion_map = dict(statement_config.parameter_config.type_coercion_map)
234-
235-
if json_serializer:
236-
type_coercion_map[dict] = json_serializer
237-
type_coercion_map[list] = json_serializer
238-
239-
if not enable_uuid_conversion:
240-
type_coercion_map[str] = type_converter.convert_if_detected
241-
242-
param_config = statement_config.parameter_config
243-
updated_param_config = ParameterStyleConfig(
244-
default_parameter_style=param_config.default_parameter_style,
245-
supported_parameter_styles=param_config.supported_parameter_styles,
246-
supported_execution_parameter_styles=param_config.supported_execution_parameter_styles,
247-
default_execution_parameter_style=param_config.default_execution_parameter_style,
248-
type_coercion_map=type_coercion_map,
249-
has_native_list_expansion=param_config.has_native_list_expansion,
250-
needs_static_script_compilation=param_config.needs_static_script_compilation,
251-
allow_mixed_parameter_styles=param_config.allow_mixed_parameter_styles,
252-
preserve_parameter_format=param_config.preserve_parameter_format,
253-
preserve_original_params_for_many=param_config.preserve_original_params_for_many,
254-
output_transformer=param_config.output_transformer,
255-
ast_transformer=param_config.ast_transformer,
256-
)
257-
statement_config = statement_config.replace(parameter_config=updated_param_config)
235+
type_coercion_map = dict(param_config.type_coercion_map)
236+
type_coercion_map[str] = type_converter.convert_if_detected
237+
param_config = param_config.replace(type_coercion_map=type_coercion_map)
238+
239+
if param_config is not statement_config.parameter_config:
240+
statement_config = statement_config.replace(parameter_config=param_config)
258241

259242
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
260243
self._data_dictionary: SyncDataDictionaryBase | None = None
@@ -294,26 +277,6 @@ def _try_special_handling(self, cursor: Any, statement: SQL) -> "SQLResult | Non
294277
_ = (cursor, statement)
295278
return None
296279

297-
def _is_modifying_operation(self, statement: SQL) -> bool:
298-
"""Check if the SQL statement modifies data.
299-
300-
Determines if a statement is an INSERT, UPDATE, or DELETE operation
301-
using AST analysis when available, falling back to text parsing.
302-
303-
Args:
304-
statement: SQL statement to analyze
305-
306-
Returns:
307-
True if the operation modifies data (INSERT/UPDATE/DELETE)
308-
"""
309-
310-
expression = statement.expression
311-
if expression and isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
312-
return True
313-
314-
sql_upper = statement.sql.strip().upper()
315-
return any(sql_upper.startswith(op) for op in MODIFYING_OPERATIONS)
316-
317280
def _execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult":
318281
"""Execute SQL script with statement splitting and parameter handling.
319282
@@ -359,7 +322,7 @@ def _execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult":
359322
if prepared_parameters:
360323
cursor.executemany(sql, prepared_parameters)
361324

362-
if self._is_modifying_operation(statement):
325+
if statement.is_modifying_operation():
363326
row_count = len(prepared_parameters)
364327
else:
365328
try:

sqlspec/core/cache.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mypy_extensions import mypyc_attr
2020
from typing_extensions import TypeVar
2121

22+
from sqlspec.core.pipeline import get_statement_pipeline_metrics, reset_statement_pipeline_cache
2223
from sqlspec.utils.logging import get_logger
2324

2425
if TYPE_CHECKING:
@@ -40,6 +41,8 @@
4041
"get_cache",
4142
"get_cache_config",
4243
"get_default_cache",
44+
"get_pipeline_metrics",
45+
"reset_pipeline_registry",
4346
)
4447

4548
T = TypeVar("T")
@@ -768,3 +771,15 @@ def to_canonical(self) -> "tuple[Any, ...]":
768771
filter_objects.append(Filter(f.field_name, f.operation, f.value))
769772

770773
return canonicalize_filters(filter_objects)
774+
775+
776+
def get_pipeline_metrics() -> "list[dict[str, Any]]":
777+
"""Return metrics for the shared statement pipeline cache when enabled."""
778+
779+
return get_statement_pipeline_metrics()
780+
781+
782+
def reset_pipeline_registry() -> None:
783+
"""Clear shared statement pipeline caches and metrics."""
784+
785+
reset_statement_pipeline_cache()

0 commit comments

Comments
 (0)