1010from collections .abc import Mapping , Sequence
1111from dataclasses import dataclass , field
1212from enum import Enum
13- from typing import TYPE_CHECKING , Any , Final , Optional , Union
13+ from typing import Any , Final , Optional , Union
1414
15+ from sqlglot import exp
1516from typing_extensions import TypedDict
1617
1718from sqlspec .exceptions import ExtraParameterError , MissingParameterError , ParameterStyleMismatchError
1819from sqlspec .typing import SQLParameterType
1920
20- if TYPE_CHECKING :
21- from sqlglot import exp
22-
2321# Constants
2422MAX_32BIT_INT : Final [int ] = 2147483647
2523
2826 "ParameterConverter" ,
2927 "ParameterInfo" ,
3028 "ParameterStyle" ,
31- "ParameterStyleTransformationState " ,
29+ "ParameterStyleConversionState " ,
3230 "ParameterValidator" ,
3331 "SQLParameterType" ,
3432 "TypedParameter" ,
@@ -169,7 +167,7 @@ class ParameterStyleInfo(TypedDict, total=False):
169167
170168
171169@dataclass
172- class ParameterStyleTransformationState :
170+ class ParameterStyleConversionState :
173171 """Encapsulates all information about parameter style transformation.
174172
175173 This class provides a single source of truth for parameter style conversions,
@@ -213,7 +211,7 @@ class ConvertedParameters:
213211 merged_parameters : "SQLParameterType"
214212 """Parameters after merging from various sources."""
215213
216- conversion_state : ParameterStyleTransformationState
214+ conversion_state : ParameterStyleConversionState
217215 """Complete conversion state for tracking conversions."""
218216
219217
@@ -314,17 +312,13 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
314312 """
315313 if not parameters_info :
316314 return ParameterStyle .NONE
317-
318- # Note: This logic prioritizes pyformat if present, then named, then positional.
319315 is_pyformat_named = any (p .style == ParameterStyle .NAMED_PYFORMAT for p in parameters_info )
320316 is_pyformat_positional = any (p .style == ParameterStyle .POSITIONAL_PYFORMAT for p in parameters_info )
321317
322318 if is_pyformat_named :
323319 return ParameterStyle .NAMED_PYFORMAT
324- if is_pyformat_positional : # If only PYFORMAT_POSITIONAL and not PYFORMAT_NAMED
320+ if is_pyformat_positional :
325321 return ParameterStyle .POSITIONAL_PYFORMAT
326-
327- # Simplified logic if not pyformat, checks for any named or any positional
328322 has_named = any (
329323 p .style
330324 in {
@@ -336,13 +330,7 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
336330 for p in parameters_info
337331 )
338332 has_positional = any (p .style in {ParameterStyle .QMARK , ParameterStyle .NUMERIC } for p in parameters_info )
339-
340- # If mixed named and positional (non-pyformat), prefer named as dominant.
341- # The choice of NAMED_COLON here is somewhat arbitrary if multiple named styles are mixed.
342333 if has_named :
343- # Could refine to return the style of the first named param encountered, or most frequent.
344- # For simplicity, returning a general named style like NAMED_COLON is often sufficient.
345- # Or, more accurately, find the first named style:
346334 for p_style in (
347335 ParameterStyle .NAMED_COLON ,
348336 ParameterStyle .POSITIONAL_COLON ,
@@ -354,12 +342,11 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
354342 return ParameterStyle .NAMED_COLON
355343
356344 if has_positional :
357- # Similarly, could choose QMARK or NUMERIC based on presence.
358345 if any (p .style == ParameterStyle .NUMERIC for p in parameters_info ):
359346 return ParameterStyle .NUMERIC
360- return ParameterStyle .QMARK # Default positional
347+ return ParameterStyle .QMARK
361348
362- return ParameterStyle .NONE # Should not be reached if parameters_info is not empty
349+ return ParameterStyle .NONE
363350
364351 @staticmethod
365352 def determine_parameter_input_type (parameters_info : "list[ParameterInfo]" ) -> "Optional[type]" :
@@ -384,9 +371,8 @@ def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "O
384371 if any (
385372 p .name is not None and p .style not in {ParameterStyle .POSITIONAL_COLON , ParameterStyle .NUMERIC }
386373 for p in parameters_info
387- ): # True for NAMED styles and PYFORMAT_NAMED
374+ ):
388375 return dict
389- # All parameters must have p.name is None or be positional styles (POSITIONAL_COLON, NUMERIC)
390376 if all (
391377 p .name is None or p .style in {ParameterStyle .POSITIONAL_COLON , ParameterStyle .NUMERIC }
392378 for p in parameters_info
@@ -400,9 +386,7 @@ def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "O
400386 "Ambiguous parameter structure for determining input type. "
401387 "Query might contain a mix of named and unnamed styles not typically supported together."
402388 )
403- # Defaulting to dict if any named param is found, as that's the more common requirement for mixed scenarios.
404- # However, strict validation should ideally prevent such mixed styles from being valid.
405- return dict # Or raise an error for unsupported mixed styles.
389+ return dict
406390
407391 def validate_parameters (
408392 self ,
@@ -421,12 +405,7 @@ def validate_parameters(
421405 ParameterStyleMismatchError: When style doesn't match
422406 """
423407 expected_input_type = self .determine_parameter_input_type (parameters_info )
424-
425- # Allow creating SQL statements with placeholders but no parameters
426- # This enables patterns like SQL("SELECT * FROM users WHERE id = ?").as_many([...])
427- # Validation will happen later when parameters are actually provided
428408 if provided_params is None and parameters_info :
429- # Don't raise an error, just return - validation will happen later
430409 return
431410
432411 if (
@@ -707,7 +686,7 @@ def convert_parameters(
707686 self .validator .validate_parameters (parameters_info , merged_params , sql )
708687 if needs_conversion :
709688 transformed_sql , placeholder_map = self ._transform_sql_for_parsing (sql , parameters_info )
710- conversion_state = ParameterStyleTransformationState (
689+ conversion_state = ParameterStyleConversionState (
711690 was_transformed = True ,
712691 original_styles = list ({p .style for p in parameters_info }),
713692 transformation_style = ParameterStyle .NAMED_COLON ,
@@ -716,7 +695,7 @@ def convert_parameters(
716695 )
717696 else :
718697 transformed_sql = sql
719- conversion_state = ParameterStyleTransformationState (
698+ conversion_state = ParameterStyleConversionState (
720699 was_transformed = False ,
721700 original_styles = list ({p .style for p in parameters_info }),
722701 original_param_info = parameters_info ,
@@ -775,10 +754,10 @@ def merge_parameters(
775754 return parameters
776755
777756 if kwargs is not None :
778- return dict (kwargs ) # Make a copy
757+ return dict (kwargs )
779758
780759 if args is not None :
781- return list (args ) # Convert tuple of args to list for consistency and mutability if needed later
760+ return list (args )
782761
783762 return None
784763
@@ -809,53 +788,34 @@ def wrap_parameters_with_types(
809788
810789 def infer_type_from_value (value : Any ) -> tuple [str , "exp.DataType" ]:
811790 """Infer SQL type hint and SQLGlot DataType from Python value."""
812- # Import here to avoid issues
813- from sqlglot import exp
814791
815792 # None/NULL
816793 if value is None :
817794 return "null" , exp .DataType .build ("NULL" )
818-
819- # Boolean
820795 if isinstance (value , bool ):
821796 return "boolean" , exp .DataType .build ("BOOLEAN" )
822-
823- # Integer types
824797 if isinstance (value , int ) and not isinstance (value , bool ):
825798 if abs (value ) > MAX_32BIT_INT :
826799 return "bigint" , exp .DataType .build ("BIGINT" )
827800 return "integer" , exp .DataType .build ("INT" )
828-
829- # Float/Decimal
830801 if isinstance (value , float ):
831802 return "float" , exp .DataType .build ("FLOAT" )
832803 if isinstance (value , Decimal ):
833804 return "decimal" , exp .DataType .build ("DECIMAL" )
834-
835- # Date/Time types
836805 if isinstance (value , datetime ):
837806 return "timestamp" , exp .DataType .build ("TIMESTAMP" )
838807 if isinstance (value , date ):
839808 return "date" , exp .DataType .build ("DATE" )
840809 if isinstance (value , time ):
841810 return "time" , exp .DataType .build ("TIME" )
842-
843- # JSON/Dict
844811 if isinstance (value , dict ):
845812 return "json" , exp .DataType .build ("JSON" )
846-
847- # Array/List
848813 if isinstance (value , (list , tuple )):
849814 return "array" , exp .DataType .build ("ARRAY" )
850-
851815 if isinstance (value , str ):
852816 return "string" , exp .DataType .build ("VARCHAR" )
853-
854- # Bytes
855817 if isinstance (value , bytes ):
856818 return "binary" , exp .DataType .build ("BINARY" )
857-
858- # Default fallback
859819 return "string" , exp .DataType .build ("VARCHAR" )
860820
861821 def wrap_value (value : Any , semantic_name : Optional [str ] = None ) -> Any :
0 commit comments