1414from typing import Any , Optional , Union
1515
1616from sqlspec .core .cache import CacheKey , get_cache_config , get_default_cache
17- from sqlspec .core .parameters import ParameterStyleConfig , ParameterValidator
18- from sqlspec .core .statement import SQL , StatementConfig
17+ from sqlspec .core .statement import SQL
1918from sqlspec .exceptions import (
2019 MissingDependencyError ,
2120 SQLFileNotFoundError ,
3433# Matches: -- name: query_name (supports hyphens and special suffixes)
3534# We capture the name plus any trailing special characters
3635QUERY_NAME_PATTERN = re .compile (r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$" , re .MULTILINE | re .IGNORECASE )
37- TRIM_SPECIAL_CHARS = re .compile (r"[^\w-]" )
36+ TRIM_SPECIAL_CHARS = re .compile (r"[^\w. -]" )
3837
3938# Matches: -- dialect: dialect_name (optional dialect specification)
4039DIALECT_PATTERN = re .compile (r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$" , re .IGNORECASE | re .MULTILINE )
@@ -581,8 +580,11 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
581580 Raises:
582581 ValueError: If query name already exists.
583582 """
584- if name in self ._queries :
585- existing_source = self ._query_to_file .get (name , "<directly added>" )
583+ # Normalize the name for consistency with file-loaded queries
584+ normalized_name = _normalize_query_name (name )
585+
586+ if normalized_name in self ._queries :
587+ existing_source = self ._query_to_file .get (normalized_name , "<directly added>" )
586588 msg = f"Query name '{ name } ' already exists (source: { existing_source } )"
587589 raise ValueError (msg )
588590
@@ -599,21 +601,16 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
599601 else :
600602 dialect = normalized_dialect
601603
602- statement = NamedStatement (name = name , sql = sql .strip (), dialect = dialect , start_line = 0 )
603- self ._queries [name ] = statement
604- self ._query_to_file [name ] = "<directly added>"
604+ statement = NamedStatement (name = normalized_name , sql = sql .strip (), dialect = dialect , start_line = 0 )
605+ self ._queries [normalized_name ] = statement
606+ self ._query_to_file [normalized_name ] = "<directly added>"
605607
606- def get_sql (
607- self , name : str , parameters : "Optional[Any]" = None , dialect : "Optional[str]" = None , ** kwargs : "Any"
608- ) -> "SQL" :
609- """Get a SQL object by statement name with dialect support.
608+ def get_sql (self , name : str ) -> "SQL" :
609+ """Get a SQL object by statement name.
610610
611611 Args:
612612 name: Name of the statement (from -- name: in SQL file).
613613 Hyphens in names are converted to underscores.
614- parameters: Parameters for the SQL statement.
615- dialect: Optional dialect override.
616- **kwargs: Additional parameters to pass to the SQL object.
617614
618615 Returns:
619616 SQL object ready for execution.
@@ -640,46 +637,11 @@ def get_sql(
640637 raise SQLFileNotFoundError (name , path = f"Statement '{ name } ' not found. Available statements: { available } " )
641638
642639 parsed_statement = self ._queries [safe_name ]
643-
644- effective_dialect = dialect or parsed_statement .dialect
645-
646- if dialect is not None :
647- normalized_dialect = _normalize_dialect (dialect )
648- if normalized_dialect not in SUPPORTED_DIALECTS :
649- suggestions = _get_dialect_suggestions (normalized_dialect )
650- warning_msg = f"Unknown dialect '{ dialect } '"
651- if suggestions :
652- warning_msg += f". Did you mean: { ', ' .join (suggestions )} ?"
653- warning_msg += f". Supported dialects: { ', ' .join (sorted (SUPPORTED_DIALECTS ))} . Using dialect as-is."
654- logger .warning (warning_msg )
655- effective_dialect = dialect .lower ()
656- else :
657- effective_dialect = normalized_dialect
658-
659- sql_kwargs = dict (kwargs )
660- if parameters is not None :
661- sql_kwargs ["parameters" ] = parameters
662-
663640 sqlglot_dialect = None
664- if effective_dialect :
665- sqlglot_dialect = _normalize_dialect_for_sqlglot (effective_dialect )
666-
667- if not effective_dialect and "statement_config" not in sql_kwargs :
668- validator = ParameterValidator ()
669- param_info = validator .extract_parameters (parsed_statement .sql )
670- if param_info :
671- styles = {p .style for p in param_info }
672- if styles :
673- detected_style = next (iter (styles ))
674- sql_kwargs ["statement_config" ] = StatementConfig (
675- parameter_config = ParameterStyleConfig (
676- default_parameter_style = detected_style ,
677- supported_parameter_styles = styles ,
678- preserve_parameter_format = True ,
679- )
680- )
641+ if parsed_statement .dialect :
642+ sqlglot_dialect = _normalize_dialect_for_sqlglot (parsed_statement .dialect )
681643
682- return SQL (parsed_statement .sql , dialect = sqlglot_dialect , ** sql_kwargs )
644+ return SQL (parsed_statement .sql , dialect = sqlglot_dialect )
683645
684646 def get_file (self , path : Union [str , Path ]) -> "Optional[SQLFile]" :
685647 """Get a loaded SQLFile object by path.
0 commit comments