|
14 | 14 | from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary |
15 | 15 | from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter |
16 | 16 | from sqlspec.core.cache import get_cache_config |
17 | | -from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig |
| 17 | +from sqlspec.core.parameters import ( |
| 18 | + ParameterProfile, |
| 19 | + ParameterStyle, |
| 20 | + ParameterStyleConfig, |
| 21 | + ParameterValidator, |
| 22 | + validate_parameter_alignment, |
| 23 | +) |
18 | 24 | from sqlspec.core.result import create_arrow_result |
19 | 25 | from sqlspec.core.statement import SQL, StatementConfig |
20 | 26 | from sqlspec.driver import SyncDriverAdapterBase |
|
69 | 75 | "snowflake": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]), |
70 | 76 | } |
71 | 77 |
|
72 | | - |
73 | | -def _count_placeholders(expression: Any) -> int: |
74 | | - """Count the number of unique parameter placeholders in a SQLGlot expression. |
75 | | -
|
76 | | - For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2) |
77 | | - For QMARK (?) style: counts total occurrences (each ? is a separate parameter) |
78 | | - For named (:name) style: counts unique parameter names |
79 | | -
|
80 | | - Args: |
81 | | - expression: SQLGlot AST expression |
82 | | -
|
83 | | - Returns: |
84 | | - Number of unique parameter placeholders expected |
85 | | - """ |
86 | | - numeric_params = set() # For $1, $2 style |
87 | | - qmark_count = 0 # For ? style |
88 | | - named_params = set() # For :name style |
89 | | - |
90 | | - def count_node(node: Any) -> Any: |
91 | | - nonlocal qmark_count |
92 | | - if isinstance(node, exp.Parameter): |
93 | | - # PostgreSQL style: $1, $2, etc. |
94 | | - param_str = str(node) |
95 | | - if param_str.startswith("$") and param_str[1:].isdigit(): |
96 | | - numeric_params.add(int(param_str[1:])) |
97 | | - elif ":" in param_str: |
98 | | - # Named parameter: :name |
99 | | - named_params.add(param_str) |
100 | | - else: |
101 | | - # Other parameter formats |
102 | | - named_params.add(param_str) |
103 | | - elif isinstance(node, exp.Placeholder): |
104 | | - # QMARK style: ? |
105 | | - qmark_count += 1 |
106 | | - return node |
107 | | - |
108 | | - expression.transform(count_node) |
109 | | - |
110 | | - # Return the appropriate count based on parameter style detected |
111 | | - if numeric_params: |
112 | | - # PostgreSQL style: return highest numbered parameter |
113 | | - return max(numeric_params) |
114 | | - if named_params: |
115 | | - # Named parameters: return count of unique names |
116 | | - return len(named_params) |
117 | | - # QMARK style: return total count |
118 | | - return qmark_count |
| 78 | +_AST_PARAMETER_VALIDATOR: "ParameterValidator" = ParameterValidator() |
119 | 79 |
|
120 | 80 |
|
121 | 81 | def _is_execute_many_parameters(parameters: Any) -> bool: |
122 | 82 | """Check if parameters are in execute_many format (list/tuple of lists/tuples).""" |
123 | 83 | return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple)) |
124 | 84 |
|
125 | 85 |
|
126 | | -def _validate_parameter_counts(expression: Any, parameters: Any, dialect: str) -> None: |
127 | | - """Validate parameter count against placeholder count in SQL.""" |
128 | | - placeholder_count = _count_placeholders(expression) |
129 | | - is_execute_many = _is_execute_many_parameters(parameters) |
130 | | - |
131 | | - if is_execute_many: |
132 | | - # For execute_many, validate each inner parameter set |
133 | | - for i, param_set in enumerate(parameters): |
134 | | - param_count = len(param_set) if isinstance(param_set, (list, tuple)) else 0 |
135 | | - if param_count != placeholder_count: |
136 | | - msg = f"Parameter count mismatch in set {i}: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})" |
137 | | - raise SQLSpecError(msg) |
138 | | - else: |
139 | | - # For single execution, validate the parameter set directly |
140 | | - param_count = ( |
141 | | - len(parameters) |
142 | | - if isinstance(parameters, (list, tuple)) |
143 | | - else len(parameters) |
144 | | - if isinstance(parameters, dict) |
145 | | - else 0 |
146 | | - ) |
147 | | - |
148 | | - if param_count != placeholder_count: |
149 | | - msg = f"Parameter count mismatch: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})" |
150 | | - raise SQLSpecError(msg) |
151 | | - |
152 | | - |
153 | 86 | def _find_null_positions(parameters: Any) -> set[int]: |
154 | 87 | """Find positions of None values in parameters for single execution.""" |
155 | 88 | null_positions = set() |
@@ -187,14 +120,15 @@ def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "post |
187 | 120 | if not parameters: |
188 | 121 | return expression, parameters |
189 | 122 |
|
190 | | - # Validate parameter count before transformation |
191 | | - _validate_parameter_counts(expression, parameters, dialect) |
192 | | - |
193 | 123 | # For execute_many operations, skip AST transformation as different parameter |
194 | 124 | # sets may have None values in different positions, making transformation complex |
195 | 125 | if _is_execute_many_parameters(parameters): |
196 | 126 | return expression, parameters |
197 | 127 |
|
| 128 | + parameter_info = _AST_PARAMETER_VALIDATOR.extract_parameters(expression.sql(dialect=dialect)) |
| 129 | + parameter_profile = ParameterProfile(parameter_info) |
| 130 | + validate_parameter_alignment(parameter_profile, parameters) |
| 131 | + |
198 | 132 | # Find positions of None values for single execution |
199 | 133 | null_positions = _find_null_positions(parameters) |
200 | 134 | if not null_positions: |
|
0 commit comments