Skip to content

Commit 6eae415

Browse files
authored
feat: correctly merge parameter names & adds or builder support (#73)
Implement unique naming for parameters in CTEs to prevent collisions. Introduce support for an `or` builder to improve query flexibility.
1 parent 1c70683 commit 6eae415

File tree

6 files changed

+1481
-15
lines changed

6 files changed

+1481
-15
lines changed

sqlspec/builder/_base.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,43 @@ def _generate_unique_parameter_name(self, base_name: str) -> str:
208208

209209
return f"{base_name}_{uuid.uuid4().hex[:8]}"
210210

211+
def _merge_cte_parameters(self, cte_name: str, parameters: dict[str, Any]) -> dict[str, str]:
212+
"""Merge CTE parameters with unique naming to prevent collisions.
213+
214+
Args:
215+
cte_name: The name of the CTE for parameter prefixing
216+
parameters: The CTE's parameter dictionary
217+
218+
Returns:
219+
Mapping of old parameter names to new unique names
220+
"""
221+
param_mapping = {}
222+
for old_name, value in parameters.items():
223+
new_name = self._generate_unique_parameter_name(f"{cte_name}_{old_name}")
224+
param_mapping[old_name] = new_name
225+
self.add_parameter(value, name=new_name)
226+
return param_mapping
227+
228+
def _update_placeholders_in_expression(
229+
self, expression: exp.Expression, param_mapping: dict[str, str]
230+
) -> exp.Expression:
231+
"""Update parameter placeholders in expression to use new names.
232+
233+
Args:
234+
expression: The SQLGlot expression to update
235+
param_mapping: Mapping of old parameter names to new names
236+
237+
Returns:
238+
Updated expression with new placeholder names
239+
"""
240+
241+
def placeholder_replacer(node: exp.Expression) -> exp.Expression:
242+
if isinstance(node, exp.Placeholder) and str(node.this) in param_mapping:
243+
return exp.Placeholder(this=param_mapping[str(node.this)])
244+
return node
245+
246+
return expression.transform(placeholder_replacer, copy=False)
247+
211248
def _generate_builder_cache_key(self, config: "Optional[StatementConfig]" = None) -> str:
212249
"""Generate cache key based on builder state and configuration.
213250
@@ -276,9 +313,12 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
276313
msg = f"CTE query builder expression must be a Select, got {type(query._expression).__name__}."
277314
self._raise_sql_builder_error(msg)
278315
cte_select_expression = query._expression
279-
for p_name, p_value in query.parameters.items():
280-
unique_name = self._generate_unique_parameter_name(p_name)
281-
self.add_parameter(p_value, unique_name)
316+
param_mapping = self._merge_cte_parameters(alias, query.parameters)
317+
updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping)
318+
if not isinstance(updated_expression, exp.Select):
319+
msg = f"Updated CTE expression must be a Select, got {type(updated_expression).__name__}."
320+
self._raise_sql_builder_error(msg)
321+
cte_select_expression = updated_expression
282322

283323
elif isinstance(query, str):
284324
try:
@@ -297,7 +337,6 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
297337
else:
298338
msg = f"Invalid query type for CTE: {type(query).__name__}"
299339
self._raise_sql_builder_error(msg)
300-
return self
301340

302341
self._with_ctes[alias] = exp.CTE(this=cte_select_expression, alias=exp.to_table(alias))
303342
return self

sqlspec/builder/mixins/_cte_and_set_ops.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, st
3131
msg = "Method must be provided by QueryBuilder subclass"
3232
raise NotImplementedError(msg)
3333

34+
def _generate_unique_parameter_name(self, base_name: str) -> str:
35+
"""Generate unique parameter name - provided by QueryBuilder."""
36+
msg = "Method must be provided by QueryBuilder subclass"
37+
raise NotImplementedError(msg)
38+
39+
def _update_placeholders_in_expression(
40+
self, expression: exp.Expression, param_mapping: dict[str, str]
41+
) -> exp.Expression:
42+
"""Update parameter placeholders - provided by QueryBuilder."""
43+
msg = "Method must be provided by QueryBuilder subclass"
44+
raise NotImplementedError(msg)
45+
3446
def with_(
3547
self, name: str, query: Union[Any, str], recursive: bool = False, columns: Optional[list[str]] = None
3648
) -> Self:
@@ -69,8 +81,15 @@ def with_(
6981
parameters = built_query.parameters
7082
if parameters:
7183
if isinstance(parameters, dict):
84+
param_mapping = {}
7285
for param_name, param_value in parameters.items():
73-
self.add_parameter(param_value, name=param_name)
86+
unique_name = self._generate_unique_parameter_name(f"{name}_{param_name}")
87+
param_mapping[param_name] = unique_name
88+
self.add_parameter(param_value, name=unique_name)
89+
90+
# Update placeholders in the parsed expression
91+
if cte_expr and param_mapping:
92+
cte_expr = self._update_placeholders_in_expression(cte_expr, param_mapping)
7493
elif isinstance(parameters, (list, tuple)):
7594
for param_value in parameters:
7695
self.add_parameter(param_value)

0 commit comments

Comments
 (0)