1313from sqlglot .optimizer import optimize
1414from typing_extensions import Self
1515
16- from sqlspec .core .cache import CacheKey , get_cache_config , get_default_cache
16+ from sqlspec .core .cache import get_cache , get_cache_config
1717from sqlspec .core .hashing import hash_optimized_expression
1818from sqlspec .core .parameters import ParameterStyle , ParameterStyleConfig
1919from sqlspec .core .statement import SQL , StatementConfig
@@ -91,6 +91,36 @@ def _initialize_expression(self) -> None:
9191 "QueryBuilder._create_base_expression must return a valid sqlglot expression."
9292 )
9393
94+ def get_expression (self ) -> Optional [exp .Expression ]:
95+ """Get expression reference (no copy).
96+
97+ Returns:
98+ The current SQLGlot expression or None if not set
99+ """
100+ return self ._expression
101+
102+ def set_expression (self , expression : exp .Expression ) -> None :
103+ """Set expression with validation.
104+
105+ Args:
106+ expression: SQLGlot expression to set
107+
108+ Raises:
109+ TypeError: If expression is not a SQLGlot Expression
110+ """
111+ if not isinstance (expression , exp .Expression ):
112+ msg = f"Expected Expression, got { type (expression )} "
113+ raise TypeError (msg )
114+ self ._expression = expression
115+
116+ def has_expression (self ) -> bool :
117+ """Check if expression exists.
118+
119+ Returns:
120+ True if expression is set, False otherwise
121+ """
122+ return self ._expression is not None
123+
94124 @abstractmethod
95125 def _create_base_expression (self ) -> exp .Expression :
96126 """Create the base sqlglot expression for the specific query type.
@@ -307,12 +337,13 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
307337 cte_select_expression : exp .Select
308338
309339 if isinstance (query , QueryBuilder ):
310- if query ._expression is None :
340+ query_expr = query .get_expression ()
341+ if query_expr is None :
311342 self ._raise_sql_builder_error ("CTE query builder has no expression." )
312- if not isinstance (query . _expression , exp .Select ):
313- msg = f"CTE query builder expression must be a Select, got { type (query . _expression ).__name__ } ."
343+ if not isinstance (query_expr , exp .Select ):
344+ msg = f"CTE query builder expression must be a Select, got { type (query_expr ).__name__ } ."
314345 self ._raise_sql_builder_error (msg )
315- cte_select_expression = query . _expression
346+ cte_select_expression = query_expr
316347 param_mapping = self ._merge_cte_parameters (alias , query .parameters )
317348 updated_expression = self ._update_placeholders_in_expression (cte_select_expression , param_mapping )
318349 if not isinstance (updated_expression , exp .Select ):
@@ -398,9 +429,8 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
398429 expression , dialect = dialect_name , schema = self .schema , optimizer_settings = optimizer_settings
399430 )
400431
401- cache_key_obj = CacheKey ((cache_key ,))
402- unified_cache = get_default_cache ()
403- cached_optimized = unified_cache .get (cache_key_obj )
432+ cache = get_cache ()
433+ cached_optimized = cache .get ("optimized" , cache_key )
404434 if cached_optimized :
405435 return cast ("exp.Expression" , cached_optimized )
406436
@@ -409,7 +439,7 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
409439 expression , schema = self .schema , dialect = self .dialect_name , optimizer_settings = optimizer_settings
410440 )
411441
412- unified_cache .put (cache_key_obj , optimized )
442+ cache .put ("optimized" , cache_key , optimized )
413443
414444 except Exception :
415445 return expression
@@ -430,15 +460,14 @@ def to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
430460 return self ._to_statement (config )
431461
432462 cache_key_str = self ._generate_builder_cache_key (config )
433- cache_key = CacheKey ((cache_key_str ,))
434463
435- unified_cache = get_default_cache ()
436- cached_sql = unified_cache .get (cache_key )
464+ cache = get_cache ()
465+ cached_sql = cache .get ("builder" , cache_key_str )
437466 if cached_sql is not None :
438467 return cast ("SQL" , cached_sql )
439468
440469 sql_statement = self ._to_statement (config )
441- unified_cache .put (cache_key , sql_statement )
470+ cache .put ("builder" , cache_key_str , sql_statement )
442471
443472 return sql_statement
444473
@@ -531,3 +560,16 @@ def _merge_sql_object_parameters(self, sql_obj: Any) -> None:
531560 def parameters (self ) -> dict [str , Any ]:
532561 """Public access to query parameters."""
533562 return self ._parameters
563+
564+ def set_parameters (self , parameters : dict [str , Any ]) -> None :
565+ """Set query parameters (public API)."""
566+ self ._parameters = parameters .copy ()
567+
568+ @property
569+ def with_ctes (self ) -> "dict[str, exp.CTE]" :
570+ """Get WITH clause CTEs (public API)."""
571+ return dict (self ._with_ctes )
572+
573+ def generate_unique_parameter_name (self , base_name : str ) -> str :
574+ """Generate unique parameter name (public API)."""
575+ return self ._generate_unique_parameter_name (base_name )
0 commit comments