diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 309b93ba..fc9e5c4d 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -1,30 +1,35 @@ import contextlib +import logging import re from collections.abc import Generator from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast -from adbc_driver_manager.dbapi import Connection -from adbc_driver_manager.dbapi import Cursor as DbapiCursor +from adbc_driver_manager.dbapi import Connection, Cursor -from sqlspec._typing import ArrowTable from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T +from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError +from sqlspec.statement import SQLStatement +from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType __all__ = ("AdbcDriver",) +logger = logging.getLogger("sqlspec") + -# Regex to find :param or %(param)s style placeholders, skipping those inside quotes PARAM_REGEX = re.compile( - r""" - (?P"([^"]|\\")*") | # Double-quoted strings - (?P'([^']|\\')*') | # Single-quoted strings - : (?P[a-zA-Z_][a-zA-Z0-9_]*) | # :var_name - % \( (?P[a-zA-Z_][a-zA-Z0-9_]*) \) s # %(var_name)s + r"""(?"(?:[^"]|"")*") | # Double-quoted strings + (?P'(?:[^']|'')*') | # Single-quoted strings + (?P--.*?\n|\/\*.*?\*\/) | # SQL comments + (?P[:\$])(?P[a-zA-Z_][a-zA-Z0-9_]*) # :name or $name identifier + ) """, - re.VERBOSE, + re.VERBOSE | re.DOTALL, ) @@ -37,16 +42,40 @@ class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterPr def __init__(self, connection: "Connection") -> None: """Initialize the ADBC driver adapter.""" self.connection = connection - # Potentially introspect connection.paramstyle here if needed in the future - # For now, assume 'qmark' based on typical ADBC DBAPI behavior + self.dialect = self._get_dialect(connection) + + @staticmethod + def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911 + """Get the database dialect based on the driver name. + + Args: + connection: The ADBC connection object. + + Returns: + The database dialect. + """ + driver_name = connection.adbc_get_info()["vendor_name"].lower() + if "postgres" in driver_name: + return "postgres" + if "bigquery" in driver_name: + return "bigquery" + if "sqlite" in driver_name: + return "sqlite" + if "duckdb" in driver_name: + return "duckdb" + if "mysql" in driver_name: + return "mysql" + if "snowflake" in driver_name: + return "snowflake" + return "postgres" # default to postgresql dialect @staticmethod - def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "DbapiCursor": + def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor": return connection.cursor(*args, **kwargs) @contextmanager - def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", None, None]: - cursor: DbapiCursor = self._cursor(connection) + def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]: + cursor = self._cursor(connection) try: yield cursor finally: @@ -54,79 +83,92 @@ def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", Non cursor.close() # type: ignore[no-untyped-call] def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name or %(name)s) to positional parameters specified by `self.param_style` - if the input parameters are a dictionary. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return self._process_sql_statement(sql), parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - # Get name from whichever group matched - var_name = match.group("var_name_colon") or match.group("var_name_perc") - - if var_name is None: # Should not happen with the new regex structure - continue - - if var_name not in parameters: - placeholder = match.group(0) # Get the full matched placeholder - msg = f"Named parameter '{placeholder}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder - processed_sql += sql[last_end : match.start()] - # Append the driver's positional placeholder - processed_sql += self.param_style - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end() - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - return self._process_sql_statement(processed_sql), tuple(processed_params_list) + # Determine effective parameter type *before* calling SQLStatement + merged_params_type = dict if kwargs else type(parameters) + + # If ADBC + sqlite/duckdb + dictionary params, handle conversion manually + if self.dialect in {"sqlite", "duckdb"} and merged_params_type is dict: + logger.debug( + "ADBC/%s with dict params; bypassing SQLStatement conversion, manually converting to '?' positional.", + self.dialect, + ) + + # Combine parameters and kwargs into the actual dictionary to use + parameter_dict = {} # type: ignore[var-annotated] + if isinstance(parameters, dict): + parameter_dict.update(parameters) + if kwargs: + parameter_dict.update(kwargs) + + # Define regex locally to find :name or $name + + processed_sql_parts: list[str] = [] + ordered_params = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(sql): # Use original sql + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + leading_char = match.group("lead") # : or $ + found_params_regex.append(var_name) + # Use match span directly for replacement + start = match.start() + end = match.end() + + if var_name not in parameter_dict: + msg = f"Named parameter '{leading_char}{var_name}' found in SQL but not provided. SQL: {sql}" + raise SQLParsingError(msg) + + processed_sql_parts.extend((sql[last_end:start], "?")) # Force ? style + ordered_params.append(parameter_dict[var_name]) + last_end = end + + processed_sql_parts.append(sql[last_end:]) + + if not found_params_regex and parameter_dict: + msg = f"ADBC/{self.dialect}: Dict params provided, but no :name or $name placeholders found. SQL: {sql}" + raise ParameterStyleMismatchError(msg) + + # Key validation + provided_keys = set(parameter_dict.keys()) + missing_keys = set(found_params_regex) - provided_keys + if missing_keys: + msg = ( + f"Named parameters found in SQL ({found_params_regex}) but not provided: {missing_keys}. SQL: {sql}" + ) + raise SQLParsingError(msg) + extra_keys = provided_keys - set(found_params_regex) + if extra_keys: + logger.debug("Extra parameters provided for ADBC/%s: %s", self.dialect, extra_keys) + # Allow extra keys + + final_sql = "".join(processed_sql_parts) + final_params = tuple(ordered_params) + return final_sql, final_params + # For all other cases (other dialects, or non-dict params for sqlite/duckdb), + # use the standard SQLStatement processing. + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + return stmt.process() def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -134,7 +176,7 @@ def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = cursor.fetchall() # pyright: ignore @@ -152,8 +194,10 @@ def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -161,7 +205,7 @@ def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -176,8 +220,10 @@ def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -185,7 +231,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -201,8 +247,10 @@ def select_value( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -210,7 +258,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -224,8 +272,10 @@ def select_value_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -233,7 +283,7 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -248,7 +298,9 @@ def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -256,7 +308,7 @@ def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -267,8 +319,10 @@ def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -276,25 +330,31 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - column_names: list[str] = [] - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if len(result) == 0: # pyright: ignore[reportUnknownArgumentType] + if not result: return None + + first_row = result[0] + column_names = [c[0] for c in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType] - return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType] + + result_dict = dict(zip(column_names, first_row)) + + if schema_type is None: + return result_dict + return cast("ModelDTOT", schema_type(**result_dict)) def execute_script( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -308,33 +368,6 @@ def execute_script( cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - connection: Optional["Connection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - column_names: list[str] = [] - - with self._with_cursor(connection) as cursor: - cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] - result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if len(result) == 0: # pyright: ignore[reportUnknownArgumentType] - return None - column_names = [c[0] for c in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType] - return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] - # --- Arrow Bulk Operations --- def select_arrow( # pyright: ignore[reportUnknownParameterType] @@ -342,7 +375,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> "ArrowTable": """Execute a SQL query and return results as an Apache Arrow Table. @@ -350,7 +385,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] The results of the query as an Apache Arrow Table. """ conn = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(conn) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 06e4a5f6..b61ee3b2 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -17,6 +17,7 @@ class AiosqliteDriver(AsyncDriverAdapterProtocol["Connection"]): """SQLite Async Driver Adapter.""" connection: "Connection" + dialect: str = "sqlite" def __init__(self, connection: "Connection") -> None: self.connection = connection @@ -33,42 +34,15 @@ async def _with_cursor(self, connection: "Connection") -> "AsyncGenerator[Cursor finally: await cursor.close() - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (?) for SQLite. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - # Convert named parameters to positional parameters - processed_sql = sql - processed_params: list[Any] = [] - for key, value in parameters.items(): - # Replace :key with ? in the SQL - processed_sql = processed_sql.replace(f":{key}", "?") - processed_params.append(value) - - return processed_sql, tuple(processed_params) - async def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -76,7 +50,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = await cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -92,8 +66,10 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -101,7 +77,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -116,8 +92,10 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -125,7 +103,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -141,8 +119,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -150,7 +130,7 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType] @@ -164,8 +144,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -173,8 +155,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType] @@ -189,7 +170,9 @@ async def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -197,7 +180,7 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -208,8 +191,10 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -217,7 +202,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -234,7 +219,9 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -242,7 +229,7 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -253,8 +240,10 @@ async def execute_script_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Execute a script and return result. @@ -262,7 +251,7 @@ async def execute_script_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index ff656100..2ffa02be 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -18,6 +18,7 @@ class AsyncmyDriver(AsyncDriverAdapterProtocol["Connection"]): """Asyncmy MySQL/MariaDB Driver Adapter.""" connection: "Connection" + dialect: str = "mysql" def __init__(self, connection: "Connection") -> None: self.connection = connection @@ -40,8 +41,10 @@ async def select( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -49,7 +52,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) results = await cursor.fetchall() @@ -65,8 +68,10 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -74,7 +79,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) result = await cursor.fetchone() @@ -89,8 +94,10 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -98,7 +105,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) result = await cursor.fetchone() @@ -114,8 +121,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -123,7 +132,7 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -140,8 +149,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -149,7 +160,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -168,7 +179,9 @@ async def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -176,7 +189,7 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -187,8 +200,10 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -196,7 +211,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) column_names: list[str] = [] async with self._with_cursor(connection) as cursor: @@ -214,7 +229,9 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -222,34 +239,8 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) return "DONE" - - async def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - connection: Optional["Connection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) - result = await cursor.fetchone() - if result is None: - return None - column_names = [c[0] for c in cursor.description or []] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) - return dict(zip(column_names, result)) diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 48ad1f67..28fa4b63 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -1,9 +1,13 @@ +import logging +import re from typing import TYPE_CHECKING, Any, Optional, Union, cast from asyncpg import Connection from typing_extensions import TypeAlias from sqlspec.base import AsyncDriverAdapterProtocol, T +from sqlspec.exceptions import SQLParsingError +from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: from asyncpg.connection import Connection @@ -13,6 +17,18 @@ __all__ = ("AsyncpgConnection", "AsyncpgDriver") +logger = logging.getLogger("sqlspec") + +# Regex to find '?' placeholders, skipping those inside quotes or SQL comments +# Simplified version, assumes standard SQL quoting/comments +QMARK_REGEX = re.compile( + r"""(?P"[^"]*") | # Double-quoted strings + (?P\'[^\']*\') | # Single-quoted strings + (?P--[^\n]*|/\*.*?\*/) | # SQL comments (single/multi-line) + (?P\?) # The question mark placeholder + """, + re.VERBOSE | re.DOTALL, +) AsyncpgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" # pyright: ignore[reportMissingTypeArgument] @@ -21,23 +37,174 @@ class AsyncpgDriver(AsyncDriverAdapterProtocol["AsyncpgConnection"]): """AsyncPG Postgres Driver Adapter.""" connection: "AsyncpgConnection" + dialect: str = "postgres" def __init__(self, connection: "AsyncpgConnection") -> None: self.connection = connection - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Union[tuple[Any, ...], list[Any], dict[str, Any]]]": - sql, parameters = super()._process_sql_params(sql, parameters) - return sql, parameters if parameters is not None else () + def _process_sql_params( # noqa: C901, PLR0912, PLR0915 + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + # Use SQLStatement for parameter validation and merging first + # It also handles potential dialect-specific logic if implemented there. + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + sql, parameters = stmt.process() + + # Case 1: Parameters are effectively a dictionary (either passed as dict or via kwargs merged by SQLStatement) + if isinstance(parameters, dict): + processed_sql_parts: list[str] = [] + ordered_params = [] + last_end = 0 + param_index = 1 + found_params_regex: list[str] = [] + + # Manually parse the PROCESSED SQL for :name -> $n conversion + for match in PARAM_REGEX.finditer(sql): + # Skip matches inside quotes or comments + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): # Finds :var_name + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 # Include the ':' + end = match.end("var_name") + + # SQLStatement should have already validated parameter existence, + # but we double-check here during ordering. + if var_name not in parameters: + # This should ideally not happen if SQLStatement validation is robust. + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((sql[last_end:start], f"${param_index}")) + ordered_params.append(parameters[var_name]) + last_end = end + param_index += 1 + + processed_sql_parts.append(sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + # --- Validation --- + # Check if named placeholders were found if dict params were provided + # SQLStatement might handle this validation, but a warning here can be useful. + if not found_params_regex and parameters: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameters.keys()), + sql, + ) + # If no placeholders, return original SQL from SQLStatement and empty tuple for asyncpg + return sql, () + + # Additional checks (potentially redundant if SQLStatement covers them): + # 1. Ensure all found placeholders have corresponding params (covered by check inside loop) + # 2. Ensure all provided params correspond to a placeholder + provided_keys = set(parameters.keys()) + found_keys = set(found_params_regex) + unused_keys = provided_keys - found_keys + if unused_keys: + # SQLStatement might handle this, but log a warning just in case. + logger.warning( + "Parameters provided but not used in SQL: %s. SQL: %s", + unused_keys, + sql, + ) + + return final_sql, tuple(ordered_params) # asyncpg expects a sequence + + # Case 2: Parameters are effectively a sequence/scalar (merged by SQLStatement) + if isinstance(parameters, (list, tuple)): + # Parameters are a sequence, need to convert ? -> $n + sequence_processed_parts: list[str] = [] + param_index = 1 + last_end = 0 + qmark_found = False + + # Manually parse the PROCESSED SQL to find '?' outside comments/quotes and convert to $n + for match in QMARK_REGEX.finditer(sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue # Skip quotes and comments + + if match.group("qmark"): + qmark_found = True + start = match.start("qmark") + end = match.end("qmark") + sequence_processed_parts.extend((sql[last_end:start], f"${param_index}")) + last_end = end + param_index += 1 + + sequence_processed_parts.append(sql[last_end:]) + final_sql = "".join(sequence_processed_parts) + + # --- Validation --- + # Check if '?' was found if parameters were provided + if parameters and not qmark_found: + # SQLStatement might allow this, log a warning. + logger.warning( + "Sequence/scalar parameters provided, but no '?' placeholders found. SQL: %s", + sql, + ) + # Return PROCESSED SQL from SQLStatement as no conversion happened here + return sql, parameters + + # Check parameter count match (using count from manual parsing vs count from stmt) + expected_params = param_index - 1 + actual_params = len(parameters) + if expected_params != actual_params: + msg = ( + f"Parameter count mismatch: Processed SQL expected {expected_params} parameters ('$n'), " + f"but {actual_params} were provided by SQLStatement. " + f"Final Processed SQL: {final_sql}" + ) + raise SQLParsingError(msg) + + return final_sql, parameters + + # Case 3: Parameters are None (as determined by SQLStatement) + # processed_params is None + # Check if the SQL contains any placeholders unexpectedly + # Check for :name style + named_placeholders_found = False + for match in PARAM_REGEX.finditer(sql): + if not (match.group("dquote") or match.group("squote") or match.group("comment")) and match.group( + "var_name" + ): + named_placeholders_found = True + break + if named_placeholders_found: + msg = f"Processed SQL contains named parameters (:name) but no parameters were provided. SQL: {sql}" + raise SQLParsingError(msg) + + # Check for ? style + qmark_placeholders_found = False + for match in QMARK_REGEX.finditer(sql): + if not (match.group("dquote") or match.group("squote") or match.group("comment")) and match.group("qmark"): + qmark_placeholders_found = True + break + if qmark_placeholders_found: + msg = f"Processed SQL contains positional parameters (?) but no parameters were provided. SQL: {sql}" + raise SQLParsingError(msg) + + # No parameters provided and none found in SQL, return original SQL from SQLStatement and empty tuple + return sql, () # asyncpg expects a sequence, even if empty async def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -46,14 +213,16 @@ async def select( parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} - results = await connection.fetch(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + results = await connection.fetch(sql, *parameters) # pyright: ignore if not results: return [] if schema_type is None: @@ -65,8 +234,10 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -75,14 +246,15 @@ async def select_one( parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + result = await connection.fetchrow(sql, *parameters) # pyright: ignore result = self.check_not_found(result) if schema_type is None: @@ -95,8 +267,10 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -105,49 +279,62 @@ async def select_one_or_none( parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - result = self.check_not_found(result) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + result = await connection.fetchrow(sql, *parameters) # pyright: ignore + if result is None: + return None if schema_type is None: # Always return as dictionary - return dict(result.items()) # type: ignore[attr-defined] - return cast("ModelDTOT", schema_type(**dict(result.items()))) # type: ignore[attr-defined] + return dict(result.items()) + return cast("ModelDTOT", schema_type(**dict(result.items()))) async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. + Args: + sql: SQL statement. + parameters: Query parameters. + connection: Optional connection to use. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. + Returns: The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + result = await connection.fetchval(sql, *parameters) # pyright: ignore result = self.check_not_found(result) if schema_type is None: - return result[0] - return schema_type(result[0]) # type: ignore[call-arg] + return result + return schema_type(result) # type: ignore[call-arg] async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -155,21 +342,23 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + result = await connection.fetchval(sql, *parameters) # pyright: ignore if result is None: return None if schema_type is None: - return result[0] - return schema_type(result[0]) # type: ignore[call-arg] + return result + return schema_type(result) # type: ignore[call-arg] async def insert_update_delete( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -177,14 +366,15 @@ async def insert_update_delete( sql: SQL statement. parameters: Query parameters. connection: Optional connection to use. + **kwargs: Additional keyword arguments. Returns: Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - status = await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + status = await connection.execute(sql, *parameters) # pyright: ignore # AsyncPG returns a string like "INSERT 0 1" where the last number is the affected rows try: return int(status.split()[-1]) # pyright: ignore[reportUnknownMemberType] @@ -196,24 +386,27 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Insert, update, or delete data from the database and return result. + """Insert, update, or delete data from the database and return the affected row. Args: sql: SQL statement. parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: - The first row of results. + The affected row data as either a model instance or dictionary. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + result = await connection.fetchrow(sql, *parameters) # pyright: ignore if result is None: return None if schema_type is None: @@ -226,7 +419,9 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -234,41 +429,16 @@ async def execute_script( sql: SQL statement. parameters: Query parameters. connection: Optional connection to use. + **kwargs: Additional keyword arguments. Returns: Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - return await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else {} + return await connection.execute(sql, *parameters) # pyright: ignore - async def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - connection: Optional["AsyncpgConnection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Args: - sql: SQL statement. - parameters: Query parameters. - connection: Optional connection to use. - schema_type: Optional schema class for the result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - if result is None: - return None - if schema_type is None: - # Always return as dictionary - return dict(result.items()) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - return cast("ModelDTOT", schema_type(**dict(result.items()))) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType] + def _connection(self, connection: Optional["AsyncpgConnection"] = None) -> "AsyncpgConnection": + """Return the connection to use. If None, use the default connection.""" + return connection if connection is not None else self.connection diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 6b01453c..18a4abec 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,7 +1,6 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast -from sqlspec._typing import ArrowTable from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T if TYPE_CHECKING: @@ -19,6 +18,7 @@ class DuckDBDriver(SyncArrowBulkOperationsMixin["DuckDBPyConnection"], SyncDrive connection: "DuckDBPyConnection" use_cursor: bool = True + dialect: str = "duckdb" def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> None: self.connection = connection @@ -27,7 +27,6 @@ def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> # --- Helper Methods --- # def _cursor(self, connection: "DuckDBPyConnection") -> "DuckDBPyConnection": if self.use_cursor: - # Ignore lack of type hint on cursor() return connection.cursor() return connection @@ -40,20 +39,22 @@ def _with_cursor(self, connection: "DuckDBPyConnection") -> "Generator[DuckDBPyC finally: cursor.close() else: - yield connection # Yield the connection directly + yield connection - # --- Public API Methods (Original Implementation + _process_sql_params) --- # + # --- Public API Methods --- # def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -71,12 +72,13 @@ def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -93,12 +95,13 @@ def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -115,12 +118,13 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[DuckDBPyConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -134,11 +138,13 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[DuckDBPyConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -153,10 +159,12 @@ def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, + **kwargs: Any, ) -> int: connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return getattr(cursor, "rowcount", -1) # pyright: ignore[reportUnknownMemberType] @@ -166,11 +174,13 @@ def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -182,44 +192,17 @@ def insert_update_delete_returning( # Always return dictionaries return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (?) for DuckDB. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - # Convert named parameters to positional parameters - processed_sql = sql - processed_params: list[Any] = [] - for key, value in parameters.items(): - # Replace :key with ? in the SQL - processed_sql = processed_sql.replace(f":{key}", "?") - processed_params.append(value) - - return processed_sql, tuple(processed_params) - def execute_script( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, + **kwargs: Any, ) -> str: connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("str", getattr(cursor, "statusmessage", "DONE")) # pyright: ignore[reportUnknownMemberType] @@ -231,17 +214,12 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[DuckDBPyConnection]" = None, - ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] - """Execute a SQL query and return results as an Apache Arrow Table. - - Returns: - An Apache Arrow Table containing the query results. - """ - - conn = self._connection(connection) - processed_sql, processed_params = self._process_sql_params(sql, parameters) - - with self._with_cursor(conn) as cursor: - cursor.execute(processed_sql, processed_params) # pyright: ignore[reportUnknownMemberType] + **kwargs: Any, + ) -> "ArrowTable": + connection = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + with self._with_cursor(connection) as cursor: + cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("ArrowTable", cursor.fetch_arrow_table()) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 910c62e9..40e9cabb 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -25,6 +25,7 @@ class OracleSyncDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAda """Oracle Sync Driver Adapter.""" connection: "Connection" + dialect: str = "oracle" def __init__(self, connection: "Connection") -> None: self.connection = connection @@ -43,16 +44,25 @@ def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -71,16 +81,25 @@ def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -100,8 +119,10 @@ def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -109,7 +130,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -131,8 +152,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -140,7 +163,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -156,8 +179,10 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -165,7 +190,7 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -183,7 +208,9 @@ def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -191,7 +218,7 @@ def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -202,8 +229,10 @@ def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -211,7 +240,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -233,7 +262,9 @@ def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -241,7 +272,7 @@ def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -252,7 +283,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -261,7 +294,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) results = connection.fetch_df_all(sql, parameters) return cast("ArrowTable", ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names())) # pyright: ignore @@ -272,6 +305,7 @@ class OracleAsyncDriver( """Oracle Async Driver Adapter.""" connection: "AsyncConnection" + dialect: str = "oracle" def __init__(self, connection: "AsyncConnection") -> None: self.connection = connection @@ -290,8 +324,10 @@ async def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -299,7 +335,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -319,8 +355,10 @@ async def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -328,7 +366,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -347,8 +385,10 @@ async def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -356,7 +396,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -378,8 +418,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -387,7 +429,7 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -403,8 +445,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -412,7 +456,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -430,7 +474,9 @@ async def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -438,7 +484,7 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -449,8 +495,10 @@ async def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -458,7 +506,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -480,7 +528,9 @@ async def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -488,49 +538,20 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return str(cursor.rowcount) # pyright: ignore[reportUnknownMemberType] - async def execute_script_returning( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - connection: "Optional[AsyncConnection]" = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] - result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - - if result is None: - return None - - # Get column names - column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] - # Always return dictionaries - return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] - async def select_arrow( # pyright: ignore[reportUnknownParameterType] self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query asynchronously and return results as an Apache Arrow Table. @@ -538,12 +559,13 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: The SQL query string. parameters: Parameters for the query. connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: An Apache Arrow Table containing the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) results = await connection.fetch_df_all(sql, parameters) return ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names()) # pyright: ignore diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index bc7c4fad..345593c0 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -1,9 +1,12 @@ +import logging from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast from psycopg.rows import dict_row -from sqlspec.base import PARAM_REGEX, AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T +from sqlspec.exceptions import SQLParsingError +from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator @@ -12,6 +15,8 @@ from sqlspec.typing import ModelDTOT, StatementParameterType +logger = logging.getLogger("sqlspec") + __all__ = ("PsycopgAsyncDriver", "PsycopgSyncDriver") @@ -19,11 +24,63 @@ class PsycopgSyncDriver(SyncDriverAdapterProtocol["Connection"]): """Psycopg Sync Driver Adapter.""" connection: "Connection" - param_style: str = "%s" + dialect: str = "postgres" def __init__(self, connection: "Connection") -> None: self.connection = connection + def _process_sql_params( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process SQL and parameters, converting :name -> %(name)s if needed.""" + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + processed_sql, processed_params = stmt.process() + + if isinstance(processed_params, dict): + parameter_dict = processed_params + processed_sql_parts: list[str] = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(processed_sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {processed_sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((processed_sql[last_end:start], f"%({var_name})s")) + last_end = end + + processed_sql_parts.append(processed_sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + if not found_params_regex and parameter_dict: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameter_dict.keys()), + processed_sql, + ) + return processed_sql, parameter_dict + + return final_sql, parameter_dict + + return processed_sql, processed_params + @staticmethod @contextmanager def _with_cursor(connection: "Connection") -> "Generator[Any, None, None]": @@ -33,75 +90,15 @@ def _with_cursor(connection: "Connection") -> "Generator[Any, None, None]": finally: cursor.close() - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (%s) - if the input parameters are a dictionary. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - var_name = match.group("var_name") - if var_name is None: # Should not happen with the regex, but safeguard - continue - - if var_name not in parameters: - msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder + the driver's positional placeholder - processed_sql += sql[last_end : match.start("var_name") - 1] + "%s" - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end("var_name") - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - return processed_sql, tuple(processed_params_list) - def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, - connection: "Optional[Connection]" = None, + *, schema_type: "Optional[type[ModelDTOT]]" = None, + connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -109,7 +106,7 @@ def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) results = cursor.fetchall() @@ -125,8 +122,10 @@ def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -134,8 +133,7 @@ def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() @@ -149,8 +147,10 @@ def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -158,8 +158,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() @@ -174,8 +173,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -183,13 +184,13 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() row = self.check_not_found(row) - val = next(iter(row)) + val = next(iter(row.values())) if row else None + val = self.check_not_found(val) if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -199,8 +200,10 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -208,14 +211,15 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() if row is None: return None - val = next(iter(row)) + val = next(iter(row.values())) if row else None + if val is None: + return None if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -225,27 +229,30 @@ def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> int: - """Insert, update, or delete data from the database. + """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. Returns: - Row count affected by the operation. + The number of rows affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - return cursor.rowcount if hasattr(cursor, "rowcount") else -1 + return getattr(cursor, "rowcount", -1) # pyright: ignore[reportUnknownMemberType] def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -253,8 +260,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) result = cursor.fetchone() @@ -271,7 +277,9 @@ def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -279,49 +287,73 @@ def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - with self._with_cursor(connection) as cursor: - cursor.execute(sql, parameters) - return str(cursor.rowcount) - - def execute_script_returning( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - connection: "Optional[Connection]" = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - result = cursor.fetchone() - - if result is None: - return None - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**result)) # pyright: ignore[reportUnknownArgumentType] - return cast("dict[str, Any]", result) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE" class PsycopgAsyncDriver(AsyncDriverAdapterProtocol["AsyncConnection"]): """Psycopg Async Driver Adapter.""" connection: "AsyncConnection" - param_style: str = "%s" + dialect: str = "postgres" def __init__(self, connection: "AsyncConnection") -> None: self.connection = connection + def _process_sql_params( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process SQL and parameters, converting :name -> %(name)s if needed.""" + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + processed_sql, processed_params = stmt.process() + + if isinstance(processed_params, dict): + parameter_dict = processed_params + processed_sql_parts: list[str] = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(processed_sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {processed_sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((processed_sql[last_end:start], f"%({var_name})s")) + last_end = end + + processed_sql_parts.append(processed_sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + if not found_params_regex and parameter_dict: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameter_dict.keys()), + processed_sql, + ) + return processed_sql, parameter_dict + + return final_sql, parameter_dict + + return processed_sql, processed_params + @staticmethod @asynccontextmanager async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, None]": @@ -331,75 +363,15 @@ async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, No finally: await cursor.close() - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (%s) - if the input parameters are a dictionary. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - var_name = match.group("var_name") - if var_name is None: # Should not happen with the regex, but safeguard - continue - - if var_name not in parameters: - msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder + the driver's positional placeholder - processed_sql += sql[last_end : match.start("var_name") - 1] + "%s" - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end("var_name") - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - return processed_sql, tuple(processed_params_list) - async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -407,7 +379,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) results: list[Union[ModelDTOT, dict[str, Any]]] = [] async with self._with_cursor(connection) as cursor: @@ -424,8 +396,10 @@ async def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -433,7 +407,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -448,8 +422,10 @@ async def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, - connection: "Optional[AsyncConnection]" = None, + *, schema_type: "Optional[type[ModelDTOT]]" = None, + connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -457,7 +433,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -473,22 +449,25 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, - ) -> "Optional[Union[T, Any]]": + **kwargs: Any, + ) -> "Union[T, Any]": """Fetch a single value from the database. Returns: The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) row = await cursor.fetchone() row = self.check_not_found(row) - val = next(iter(row)) + val = next(iter(row.values())) if row else None + val = self.check_not_found(val) if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -498,8 +477,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -507,14 +488,16 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) row = await cursor.fetchone() if row is None: return None - val = next(iter(row)) + val = next(iter(row.values())) if row else None + if val is None: + return None if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -524,15 +507,17 @@ async def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> int: - """Insert, update, or delete data from the database. + """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. Returns: - Row count affected by the operation. + The number of rows affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -547,8 +532,10 @@ async def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -556,7 +543,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -574,7 +561,9 @@ async def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -582,35 +571,8 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) - return str(cursor.rowcount) - - async def execute_script_returning( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - connection: "Optional[AsyncConnection]" = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) - result = await cursor.fetchone() - - if result is None: - return None - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**result)) # pyright: ignore[reportUnknownArgumentType] - return cast("dict[str, Any]", result) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE" diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 08e81f0a..7c5916ab 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -16,6 +16,7 @@ class SqliteDriver(SyncDriverAdapterProtocol["Connection"]): """SQLite Sync Driver Adapter.""" connection: "Connection" + dialect: str = "sqlite" def __init__(self, connection: "Connection") -> None: self.connection = connection @@ -37,8 +38,10 @@ def select( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -46,7 +49,7 @@ def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -65,8 +68,10 @@ def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -74,7 +79,7 @@ def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -92,8 +97,10 @@ def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -101,7 +108,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -120,8 +127,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -129,7 +138,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -146,8 +155,10 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -155,7 +166,7 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -173,7 +184,9 @@ def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -181,7 +194,7 @@ def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: @@ -195,8 +208,10 @@ def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -204,7 +219,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: @@ -214,46 +229,32 @@ def insert_update_delete_returning( result = cursor.fetchall() if len(result) == 0: return None - column_names = [c[0] for c in cursor.description or []] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) - return dict(zip(column_names, result[0])) - - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - Converts named parameters (:name) to positional parameters (?) for SQLite. + # Get column names from cursor description + column_names = [c[0] for c in cursor.description or []] - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). + # Get the first row's values - ensure we're getting the actual values + row_values = result[0] - Returns: - A tuple containing the processed SQL string and the processed parameters. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters + # Debug print to see what we're getting - # Convert named parameters to positional parameters - processed_sql = sql - processed_params: list[Any] = [] - for key, value in parameters.items(): - # Replace :key with ? in the SQL - processed_sql = processed_sql.replace(f":{key}", "?") - processed_params.append(value) + # Create dictionary mapping column names to values + result_dict = {} + for i, col_name in enumerate(column_names): + result_dict[col_name] = row_values[i] - return processed_sql, tuple(processed_params) + if schema_type is not None: + return cast("ModelDTOT", schema_type(**result_dict)) + return result_dict def execute_script( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -261,25 +262,26 @@ def execute_script( Status message for the operation. """ connection = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - # For DDL statements, don't pass parameters to execute - # SQLite doesn't support parameters for DDL statements + # The _process_sql_params handles parameter formatting for the dialect. with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] else: - sql, parameters = self._process_sql_params(sql, parameters) - cursor.execute(sql, parameters) # type: ignore[arg-type] + cursor.execute(sql, parameters) - return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] + return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] def execute_script_returning( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Execute a script and return result. @@ -287,7 +289,7 @@ def execute_script_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: diff --git a/sqlspec/base.py b/sqlspec/base.py index 4e727950..855432b1 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -17,6 +17,7 @@ ) from sqlspec.exceptions import NotFoundError +from sqlspec.statement import SQLStatement from sqlspec.typing import ModelDTOT, StatementParameterType if TYPE_CHECKING: @@ -34,6 +35,7 @@ "NoPoolAsyncConfig", "NoPoolSyncConfig", "SQLSpec", + "SQLStatement", "SyncArrowBulkOperationsMixin", "SyncDatabaseConfig", "SyncDriverAdapterProtocol", @@ -50,13 +52,15 @@ bound="Union[Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]], SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]", ) DriverT = TypeVar("DriverT", bound="Union[SyncDriverAdapterProtocol[Any], AsyncDriverAdapterProtocol[Any]]") - -# Regex to find :param style placeholders, avoiding those inside quotes -# Handles basic cases, might need refinement for complex SQL +# Regex to find :param or %(param)s style placeholders, skipping those inside quotes PARAM_REGEX = re.compile( - r"(?P\"(?:[^\"]|\"\")*\")|" # Double-quoted strings - r"(?P'(?:[^']|'')*')|" # Single-quoted strings - r"(?P[^:]):(?P[a-zA-Z_][a-zA-Z0-9_]*)" # :param placeholder + r""" + (?P"([^"]|\\")*") | # Double-quoted strings + (?P'([^']|\\')*') | # Single-quoted strings + : (?P[a-zA-Z_][a-zA-Z0-9_]*) | # :var_name + % \( (?P[a-zA-Z_][a-zA-Z0-9_]*) \) s # %(var_name)s + """, + re.VERBOSE, ) @@ -415,8 +419,8 @@ def close_pool( class CommonDriverAttributes(Generic[ConnectionT]): """Common attributes and methods for driver adapters.""" - param_style: str = "?" - """The parameter style placeholder supported by the underlying database driver (e.g., '?', '%s').""" + dialect: str + """The SQL dialect supported by the underlying database driver (e.g., 'postgres', 'mysql').""" connection: ConnectionT """The connection to the underlying database.""" __supports_arrow__: ClassVar[bool] = False @@ -443,83 +447,23 @@ def check_not_found(item_or_none: Optional[T] = None) -> T: raise NotFoundError(msg) return item_or_none - def _process_sql_statement(self, sql: str) -> str: - """Perform any preprocessing of the SQL query string if needed. - Default implementation returns the SQL unchanged. - - Args: - sql: The SQL query string. - - Returns: - The processed SQL query string. - """ - return sql - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None + self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, **kwargs: Any ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters specified by `self.param_style` - if the input parameters are a dictionary. + """Process SQL query and parameters using SQLStatement for validation and formatting. Args: sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). + parameters: Parameters for the query. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. + A tuple containing the processed SQL query and parameters. """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return self._process_sql_statement(sql), parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - var_name = match.group("var_name") - if var_name is None: # Should not happen with the regex, but safeguard - continue - - if var_name not in parameters: - msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder + the leading character + the driver's positional placeholder - # The match.start("var_name") -1 includes the character before the ':' - processed_sql += sql[last_end : match.start("var_name")] + self.param_style - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end("var_name") - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - processed_params = tuple(processed_params_list) - # Pass the processed SQL through the driver-specific processor if needed - final_sql = self._process_sql_statement(processed_sql) - return final_sql, processed_params + # Instantiate SQLStatement with parameters and kwargs for internal merging + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + # Process uses the merged parameters internally + return stmt.process() class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): @@ -527,16 +471,15 @@ class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): __supports_arrow__: "ClassVar[bool]" = True - def __init__(self, connection: ConnectionT) -> None: - self.connection = connection - @abstractmethod def select_arrow( # pyright: ignore[reportUnknownParameterType] self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -544,6 +487,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: The SQL query string. parameters: Parameters for the query. connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: An Apache Arrow Table containing the query results. @@ -554,7 +498,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): connection: "ConnectionT" - def __init__(self, connection: "ConnectionT") -> None: + def __init__(self, connection: "ConnectionT", **kwargs: Any) -> None: self.connection = connection @abstractmethod @@ -563,8 +507,10 @@ def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -573,8 +519,10 @@ def select_one( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod @@ -583,8 +531,10 @@ def select_one_or_none( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -593,8 +543,10 @@ def select_value( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, + **kwargs: Any, ) -> "Union[Any, T]": ... @abstractmethod @@ -603,8 +555,10 @@ def select_value_or_none( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, + **kwargs: Any, ) -> "Optional[Union[Any, T]]": ... @abstractmethod @@ -613,7 +567,9 @@ def insert_update_delete( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, + **kwargs: Any, ) -> int: ... @abstractmethod @@ -622,8 +578,10 @@ def insert_update_delete_returning( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... @abstractmethod @@ -632,7 +590,9 @@ def execute_script( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, + **kwargs: Any, ) -> str: ... @@ -647,7 +607,9 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -655,6 +617,7 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: The SQL query string. parameters: Parameters for the query. connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: An Apache Arrow Table containing the query results. @@ -674,8 +637,10 @@ async def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -684,8 +649,10 @@ async def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod @@ -694,8 +661,10 @@ async def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -704,8 +673,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[Any, T]": ... @abstractmethod @@ -714,8 +685,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[Any, T]]": ... @abstractmethod @@ -724,7 +697,9 @@ async def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> int: ... @abstractmethod @@ -733,8 +708,10 @@ async def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... @abstractmethod @@ -743,7 +720,9 @@ async def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> str: ... diff --git a/sqlspec/exceptions.py b/sqlspec/exceptions.py index e78eea62..73eac9e3 100644 --- a/sqlspec/exceptions.py +++ b/sqlspec/exceptions.py @@ -1,3 +1,5 @@ +from collections.abc import Generator +from contextlib import contextmanager from typing import Any, Optional __all__ = ( @@ -6,7 +8,9 @@ "MissingDependencyError", "MultipleResultsFoundError", "NotFoundError", + "ParameterStyleMismatchError", "RepositoryError", + "SQLParsingError", "SQLSpecError", "SerializationError", ) @@ -74,6 +78,20 @@ def __init__(self, message: Optional[str] = None) -> None: super().__init__(message) +class ParameterStyleMismatchError(SQLSpecError): + """Error when parameter style doesn't match SQL placeholder style. + + This exception is raised when there's a mismatch between the parameter type + (dictionary, tuple, etc.) and the placeholder style in the SQL query + (named, positional, etc.). + """ + + def __init__(self, message: Optional[str] = None) -> None: + if message is None: + message = "Parameter style mismatch: dictionary parameters provided but no named placeholders found in SQL." + super().__init__(message) + + class ImproperConfigurationError(SQLSpecError): """Improper Configuration error. @@ -99,3 +117,15 @@ class NotFoundError(RepositoryError): class MultipleResultsFoundError(RepositoryError): """A single database result was required but more than one were found.""" + + +@contextmanager +def wrap_exceptions(wrap_exceptions: bool = True) -> Generator[None, None, None]: + try: + yield + + except Exception as exc: + if wrap_exceptions is False: + raise + msg = "An error occurred during the operation." + raise RepositoryError(detail=msg) from exc diff --git a/sqlspec/statement.py b/sqlspec/statement.py new file mode 100644 index 00000000..a02896d5 --- /dev/null +++ b/sqlspec/statement.py @@ -0,0 +1,373 @@ +# ruff: noqa: RUF100, PLR6301, PLR0912, PLR0915, C901, PLR0911, PLR0914, N806 +import logging +import re +from dataclasses import dataclass +from functools import cached_property +from typing import ( + Any, + Optional, + Union, +) + +import sqlglot +from sqlglot import exp + +from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError +from sqlspec.typing import StatementParameterType + +__all__ = ("SQLStatement",) + +logger = logging.getLogger("sqlspec") + +# Regex to find :param style placeholders, skipping those inside quotes or SQL comments +# Adapted from previous version in psycopg adapter +PARAM_REGEX = re.compile( + r"""(?"(?:[^"]|"")*") | # Double-quoted strings (support SQL standard escaping "") + (?P'(?:[^']|'')*') | # Single-quoted strings (support SQL standard escaping '') + (?P--.*?\n|\/\*.*?\*\/) | # SQL comments (single line or multi-line) + : (?P[a-zA-Z_][a-zA-Z0-9_]*) # :var_name identifier + ) + """, + re.VERBOSE | re.DOTALL, +) + + +@dataclass() +class SQLStatement: + """An immutable representation of a SQL statement with its parameters. + + This class encapsulates the SQL statement and its parameters, providing + a clean interface for parameter binding and SQL statement formatting. + """ + + dialect: str + """The SQL dialect to use for parsing (e.g., 'postgres', 'mysql'). Defaults to 'postgres' if None.""" + sql: str + """The raw SQL statement.""" + parameters: Optional[StatementParameterType] = None + """The parameters for the SQL statement.""" + kwargs: Optional[dict[str, Any]] = None + """Keyword arguments passed for parameter binding.""" + + _merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = None + + def __post_init__(self) -> None: + """Merge parameters and kwargs after initialization.""" + merged_params = self.parameters + + if self.kwargs: + if merged_params is None: + merged_params = self.kwargs + elif isinstance(merged_params, dict): + # Merge kwargs into parameters dict, kwargs take precedence + merged_params = {**merged_params, **self.kwargs} + else: + # If parameters is sequence or scalar, kwargs replace it + # Consider adding a warning here if this behavior is surprising + merged_params = self.kwargs + + self._merged_parameters = merged_params + + def process(self) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process the SQL statement and merged parameters for execution. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + + Raises: + SQLParsingError: If the SQL statement contains parameter placeholders, but no parameters were provided. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + if self._merged_parameters is None: + # Validate that the SQL doesn't expect parameters if none were provided + # Parse ONLY if we need to validate + try: # Add try/except in case parsing fails even here + expression = self._parse_sql() + except SQLParsingError: + # If parsing fails, we can't validate, but maybe that's okay if no params were passed? + # Log a warning? For now, let the original error propagate if needed. + # Or, maybe assume it's okay if _merged_parameters is None? + # Let's re-raise for now, as unparsable SQL is usually bad. + logger.warning("SQL statement is unparsable: %s", self.sql) + return self.sql, None + if list(expression.find_all(exp.Parameter)): + msg = "SQL statement contains parameter placeholders, but no parameters were provided." + raise SQLParsingError(msg) + return self.sql, None + + if isinstance(self._merged_parameters, dict): + # Pass only the dict, parsing happens inside + return self._process_dict_params(self._merged_parameters) + + if isinstance(self._merged_parameters, (tuple, list)): + # Pass only the sequence, parsing happens inside if needed for validation + return self._process_sequence_params(self._merged_parameters) + + # Assume it's a single scalar value otherwise + # Pass only the value, parsing happens inside for validation + return self._process_scalar_param(self._merged_parameters) + + def _parse_sql(self) -> exp.Expression: + """Parse the SQL using sqlglot. + + Raises: + SQLParsingError: If the SQL statement cannot be parsed. + + Returns: + The parsed SQL expression. + """ + parse_dialect = self.dialect or "postgres" + try: + read_dialect = parse_dialect or None + return sqlglot.parse_one(self.sql, read=read_dialect) + except Exception as e: + # Ensure the original sqlglot error message is included + error_detail = str(e) + msg = f"Failed to parse SQL with dialect '{parse_dialect or 'auto-detected'}': {error_detail}\nSQL: {self.sql}" + raise SQLParsingError(msg) from e + + def _process_dict_params( + self, + parameter_dict: dict[str, Any], + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Processes dictionary parameters based on dialect capabilities. + + Raises: + ParameterStyleMismatchError: If the SQL statement contains unnamed placeholders (e.g., '?') in the SQL query. + SQLParsingError: If the SQL statement contains named parameters, but no parameters were provided. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + # Attempt to parse with sqlglot first (for other dialects like postgres, mysql) + named_sql_params: Optional[list[exp.Parameter]] = None + unnamed_sql_params: Optional[list[exp.Parameter]] = None + sqlglot_parsed_ok = False + # --- Dialect-Specific Bypasses for Native Handling --- + if self.dialect == "sqlite": # Handles :name natively + return self.sql, parameter_dict + + # Add bypass for postgres handled by specific adapters (e.g., asyncpg) + if self.dialect == "postgres": + # The adapter (e.g., asyncpg) will handle :name -> $n conversion. + # SQLStatement just validates parameters against the original SQL here. + # Perform validation using regex if sqlglot parsing fails, otherwise use sqlglot. + try: + expression = self._parse_sql() + sql_params = list(expression.find_all(exp.Parameter)) + named_sql_params = [p for p in sql_params if p.name] + unnamed_sql_params = [p for p in sql_params if not p.name] + + if unnamed_sql_params: + msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot for postgres." + raise ParameterStyleMismatchError(msg) + + # Validate keys using sqlglot results + required_keys = {p.name for p in named_sql_params} + provided_keys = set(parameter_dict.keys()) + missing_keys = required_keys - provided_keys + if missing_keys: + msg = ( + f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}" + ) + raise SQLParsingError(msg) # noqa: TRY301 + # Allow extra keys + + except SQLParsingError as e: + logger.debug("SQLglot parsing failed for postgres dict params, attempting regex validation: %s", e) + # Regex validation fallback (without conversion) + postgres_found_params_regex: list[str] = [] + for match in PARAM_REGEX.finditer(self.sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + if match.group("var_name"): + var_name = match.group("var_name") + postgres_found_params_regex.append(var_name) + if var_name not in parameter_dict: + msg = f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}" + raise SQLParsingError(msg) # noqa: B904 + + if not postgres_found_params_regex and parameter_dict: + msg = f"Dictionary parameters provided, but no named placeholders (:name) found via regex. SQL: {self.sql}" + raise ParameterStyleMismatchError(msg) # noqa: B904 + # Allow extra keys with regex check too + + # Return the *original* SQL and the processed dict for the adapter to handle + return self.sql, parameter_dict + + if self.dialect == "duckdb": # Handles $name natively (and :name via driver? Check driver docs) + # Bypass sqlglot/regex checks. Trust user SQL ($name or ?) + dict for DuckDB driver. + # We lose :name -> $name conversion *if* sqlglot parsing fails, but avoid errors on valid $name SQL. + return self.sql, parameter_dict + # --- End Bypasses --- + + try: + expression = self._parse_sql() + sql_params = list(expression.find_all(exp.Parameter)) + named_sql_params = [p for p in sql_params if p.name] + unnamed_sql_params = [p for p in sql_params if not p.name] + sqlglot_parsed_ok = True + logger.debug("SQLglot parsed dict params successfully for: %s", self.sql) + except SQLParsingError as e: + logger.debug("SQLglot parsing failed for dict params, attempting regex fallback: %s", e) + # Proceed using regex fallback below + + # Check for unnamed placeholders if parsing worked + if sqlglot_parsed_ok and unnamed_sql_params: + msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot." + raise ParameterStyleMismatchError(msg) + + # Determine if we need to use regex fallback + # Use fallback if: parsing failed OR (parsing worked BUT found no named params when a dict was provided) + use_regex_fallback = not sqlglot_parsed_ok or (not named_sql_params and parameter_dict) + + if use_regex_fallback: + # Regex fallback logic for :name -> self.param_style conversion + # ... (regex fallback code as implemented previously) ... + logger.debug("Using regex fallback for dict param processing: %s", self.sql) + # --- Regex Fallback Logic --- + regex_processed_sql_parts: list[str] = [] + ordered_params = [] + last_end = 0 + regex_found_params: list[str] = [] + + for match in PARAM_REGEX.finditer(self.sql): + # Skip matches that are comments or quoted strings + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + regex_found_params.append(var_name) + # Get start and end from the match object for the :var_name part + # The var_name group itself doesn't include the leading :, so adjust start. + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}" + ) + raise SQLParsingError(msg) + + regex_processed_sql_parts.extend((self.sql[last_end:start], self.param_style)) # Use target style + ordered_params.append(parameter_dict[var_name]) + last_end = end + + regex_processed_sql_parts.append(self.sql[last_end:]) + + # Validation with regex results + if not regex_found_params and parameter_dict: + msg = f"Dictionary parameters provided, but no named placeholders (e.g., :name) found via regex in the SQL query for dialect '{self.dialect}'. SQL: {self.sql}" + raise ParameterStyleMismatchError(msg) + + provided_keys = set(parameter_dict.keys()) + missing_keys = set(regex_found_params) - provided_keys # Should be caught above, but double check + if missing_keys: + msg = f"Named parameters found in SQL (via regex) but not provided: {missing_keys}. SQL: {self.sql}" + raise SQLParsingError(msg) + + extra_keys = provided_keys - set(regex_found_params) + if extra_keys: + # Allow extra keys + pass + + return "".join(regex_processed_sql_parts), tuple(ordered_params) + + # Sqlglot Logic (if parsing worked and found params) + # ... (sqlglot logic as implemented previously, including :name -> %s conversion) ... + logger.debug("Using sqlglot results for dict param processing: %s", self.sql) + + # Ensure named_sql_params is iterable, default to empty list if None (shouldn't happen ideally) + active_named_params = named_sql_params or [] + + if not active_named_params and not parameter_dict: + # No SQL params found by sqlglot, no provided params dict -> OK + return self.sql, () + + # Validation with sqlglot results + required_keys = {p.name for p in active_named_params} # Use active_named_params + provided_keys = set(parameter_dict.keys()) + + missing_keys = required_keys - provided_keys + if missing_keys: + msg = f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}" + raise SQLParsingError(msg) + + extra_keys = provided_keys - required_keys + if extra_keys: + pass # Allow extra keys + + # Note: DuckDB handled by bypass above if sqlglot fails. + # This block handles successful sqlglot parse for other dialects. + # We don't need the specific DuckDB $name conversion here anymore, + # as the bypass handles the native $name case. + # The general logic converts :name -> self.param_style for dialects like postgres. + # if self.dialect == "duckdb": ... (Removed specific block here) + + # For other dialects requiring positional conversion (using sqlglot param info): + sqlglot_processed_parts: list[str] = [] + ordered_params = [] + last_end = 0 + for param in active_named_params: # Use active_named_params + start = param.this.this.start + end = param.this.this.end + sqlglot_processed_parts.extend((self.sql[last_end:start], self.param_style)) + ordered_params.append(parameter_dict[param.name]) + last_end = end + sqlglot_processed_parts.append(self.sql[last_end:]) + return "".join(sqlglot_processed_parts), tuple(ordered_params) + + def _process_sequence_params( + self, params: Union[tuple[Any, ...], list[Any]] + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Processes a sequence of parameters. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + return self.sql, params + + def _process_scalar_param( + self, param_value: Any + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Processes a single scalar parameter value. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + return self.sql, (param_value,) + + @cached_property + def param_style(self) -> str: + """Get the parameter style based on the dialect. + + Returns: + The parameter style placeholder for the dialect. + """ + dialect = self.dialect + + # Map dialects to parameter styles for placeholder replacement + # Note: Used when converting named params (:name) for dialects needing positional. + # Dialects supporting named params natively (SQLite, DuckDB) are handled via bypasses. + dialect_to_param_style = { + "postgres": "%s", + "mysql": "%s", + "oracle": ":1", + "mssql": "?", + "bigquery": "?", + "snowflake": "?", + "cockroach": "%s", + "db2": "?", + } + # Default to '?' for unknown/unhandled dialects or when dialect=None is forced + return dialect_to_param_style.get(dialect, "?") diff --git a/sqlspec/typing.py b/sqlspec/typing.py index bb2a1cf8..ee0aae7a 100644 --- a/sqlspec/typing.py +++ b/sqlspec/typing.py @@ -79,7 +79,7 @@ - :class:`DTOData`[:type:`list[ModelT]`] """ -StatementParameterType: TypeAlias = "Union[dict[str, Any], list[Any], tuple[Any, ...], None]" +StatementParameterType: TypeAlias = "Union[Any, dict[str, Any], list[Any], tuple[Any, ...], None]" """Type alias for parameter types. Represents: diff --git a/tests/fixtures/sql_utils.py b/tests/fixtures/sql_utils.py index ca69d526..b548b2db 100644 --- a/tests/fixtures/sql_utils.py +++ b/tests/fixtures/sql_utils.py @@ -13,11 +13,13 @@ def format_placeholder(field_name: str, style: str, dialect: Optional[str] = Non The formatted placeholder string. """ if style == "tuple_binds": - if dialect in ["sqlite", "duckdb", "aiosqlite"]: + if dialect in {"sqlite", "duckdb", "aiosqlite"}: return "?" # Default to Postgres/BigQuery style return "%s" - if dialect in ["sqlite", "duckdb", "aiosqlite"]: + if dialect == "duckdb": + return f"${field_name}" + if dialect in {"sqlite", "aiosqlite"}: return f":{field_name}" # For postgres and similar return f"%({field_name})s" diff --git a/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py b/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py index 23912f19..43bd56ec 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py @@ -276,3 +276,169 @@ def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: assert arrow_table.column("id").to_pylist() == [1] driver.execute_script("DROP TABLE IF EXISTS test_table") driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_named_params_with_scalar(adbc_session: AdbcConfig) -> None: + """Test that scalar parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using positional parameter with scalar value + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + driver.insert_update_delete(insert_sql, "test_name") + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE name = ?" + results = driver.select(select_sql, "test_name") + assert len(results) == 1 + assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_named_params_with_tuple(adbc_session: AdbcConfig) -> None: + """Test that tuple parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using positional parameters with tuple values + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES (?, ?) + """ + driver.insert_update_delete(insert_sql, ("test_name", 30)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = ? AND age = ?" + results = driver.select(select_sql, ("test_name", 30)) + assert len(results) == 1 + assert results[0]["name"] == "test_name" + assert results[0]["age"] == 30 + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_native_named_params(adbc_session: AdbcConfig) -> None: + """Test DuckDB's native named parameter style ($name).""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using native $name style + insert_sql = """ + INSERT INTO test_table (name) + VALUES ($name) + """ + driver.insert_update_delete(insert_sql, {"name": "native_name"}) + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE name = $name" + results = driver.select(select_sql, {"name": "native_name"}) + assert len(results) == 1 + assert results[0]["name"] == "native_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_native_positional_params(adbc_session: AdbcConfig) -> None: + """Test DuckDB's native positional parameter style ($1, $2, etc.).""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using native $1 style + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES ($1, $2) + """ + driver.insert_update_delete(insert_sql, ("native_pos", 30)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = $1 AND age = $2" + results = driver.select(select_sql, ("native_pos", 30)) + assert len(results) == 1 + assert results[0]["name"] == "native_pos" + assert results[0]["age"] == 30 + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_native_auto_incremented_params(adbc_session: AdbcConfig) -> None: + """Test DuckDB's native auto-incremented parameter style (?).""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using native ? style + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES (?, ?) + """ + driver.insert_update_delete(insert_sql, ("native_auto", 35)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = ? AND age = ?" + results = driver.select(select_sql, ("native_auto", 35)) + assert len(results) == 1 + assert results[0]["name"] == "native_auto" + assert results[0]["age"] == 35 + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") diff --git a/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py b/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py index fa7c5ea7..8e85381d 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py @@ -250,3 +250,61 @@ def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: # Assuming id is 1 for the inserted record assert arrow_table.column("id").to_pylist() == [1] driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("sqlite") +def test_driver_named_params_with_scalar(adbc_session: AdbcConfig) -> None: + """Test that scalar parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using named parameter with scalar value + insert_sql = """ + INSERT INTO test_table (name) + VALUES (:name) + """ + driver.insert_update_delete(insert_sql, "test_name") + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE name = :name" + results = driver.select(select_sql, "test_name") + assert len(results) == 1 + assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("sqlite") +def test_driver_named_params_with_tuple(adbc_session: AdbcConfig) -> None: + """Test that tuple parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using named parameters with tuple values + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES (:name, :age) + """ + driver.insert_update_delete(insert_sql, ("test_name", 30)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = :name AND age = :age" + results = driver.select(select_sql, ("test_name", 30)) + assert len(results) == 1 + assert results[0]["name"] == "test_name" + assert results[0]["age"] == 30 + driver.execute_script("DROP TABLE IF EXISTS test_table") diff --git a/tests/integration/test_adapters/test_asyncpg/__init__.py b/tests/integration/test_adapters/test_asyncpg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_adapters/test_asyncpg/test_connection.py b/tests/integration/test_adapters/test_asyncpg/test_connection.py new file mode 100644 index 00000000..31ba7017 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_connection.py @@ -0,0 +1,42 @@ +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgPoolConfig + + +@pytest.mark.xdist_group("postgres") +async def test_async_connection(postgres_service: PostgresService) -> None: + """Test asyncpg connection components.""" + # Test direct connection + async_config = AsyncpgConfig( + pool_config=AsyncpgPoolConfig( + dsn=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", + ), + ) + + conn = await async_config.create_connection() + try: + assert conn is not None + # Test basic query + result = await conn.fetchval("SELECT 1") + assert result == 1 + finally: + await conn.close() + + # Test connection pool + pool_config = AsyncpgPoolConfig( + dsn=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", + min_size=1, + max_size=5, + ) + another_config = AsyncpgConfig(pool_config=pool_config) + # Ensure the pool is created before use if not explicitly managed elsewhere + await another_config.create_pool() + try: + async with another_config.provide_connection() as conn: + assert conn is not None + # Test basic query + result = await conn.fetchval("SELECT 1") + assert result == 1 + finally: + await another_config.close_pool() diff --git a/tests/integration/test_adapters/test_asyncpg/test_driver.py b/tests/integration/test_adapters/test_asyncpg/test_driver.py new file mode 100644 index 00000000..4e62cedf --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_driver.py @@ -0,0 +1,275 @@ +"""Test Asyncpg driver implementation.""" + +from __future__ import annotations + +from typing import Any, Literal + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgPoolConfig + +ParamStyle = Literal["tuple_binds", "dict_binds"] + + +@pytest.fixture +def asyncpg_config(postgres_service: PostgresService) -> AsyncpgConfig: + """Create an Asyncpg configuration. + + Args: + postgres_service: PostgreSQL service fixture. + + Returns: + Configured Asyncpg session config. + """ + return AsyncpgConfig( + pool_config=AsyncpgPoolConfig( + dsn=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", + min_size=1, # Add min_size to avoid pool deadlock issues in tests + max_size=5, + ) + ) + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param(("test_name",), "tuple_binds", id="tuple_binds"), + pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_insert_returning(asyncpg_config: AsyncpgConfig, params: Any, style: ParamStyle) -> None: + """Test async insert returning functionality with different parameter styles.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + """ + await driver.execute_script(sql) + + # Use appropriate SQL for each style (sqlspec driver handles conversion to $1, $2...) + if style == "tuple_binds": + sql = """ + INSERT INTO test_table (name) + VALUES (?) + RETURNING * + """ + else: # dict_binds + sql = """ + INSERT INTO test_table (name) + VALUES (:name) + RETURNING * + """ + + try: + result = await driver.insert_update_delete_returning(sql, params) + assert result is not None + assert result["name"] == "test_name" + assert result["id"] is not None + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param(("test_name",), "tuple_binds", id="tuple_binds"), + pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_select(asyncpg_config: AsyncpgConfig, params: Any, style: ParamStyle) -> None: + """Test async select functionality with different parameter styles.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + """ + await driver.execute_script(sql) + + # Insert test record + if style == "tuple_binds": + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + else: # dict_binds + insert_sql = """ + INSERT INTO test_table (name) + VALUES (:name) + """ + await driver.insert_update_delete(insert_sql, params) + + # Select and verify + if style == "tuple_binds": + select_sql = """ + SELECT name FROM test_table WHERE name = ? + """ + else: # dict_binds + select_sql = """ + SELECT name FROM test_table WHERE name = :name + """ + try: + results = await driver.select(select_sql, params) + assert len(results) == 1 + assert results[0]["name"] == "test_name" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param(("test_name",), "tuple_binds", id="tuple_binds"), + pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_select_value(asyncpg_config: AsyncpgConfig, params: Any, style: ParamStyle) -> None: + """Test async select_value functionality with different parameter styles.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + """ + await driver.execute_script(sql) + + # Insert test record + if style == "tuple_binds": + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + else: # dict_binds + insert_sql = """ + INSERT INTO test_table (name) + VALUES (:name) + """ + await driver.insert_update_delete(insert_sql, params) + + # Get literal string to test with select_value + # Use a literal query to test select_value + select_sql = "SELECT 'test_name' AS test_name" + + try: + # Don't pass parameters with a literal query that has no placeholders + value = await driver.select_value(select_sql) + assert value == "test_name" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_insert(asyncpg_config: AsyncpgConfig) -> None: + """Test inserting data.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ) + """ + await driver.execute_script(sql) + + insert_sql = "INSERT INTO test_table (name) VALUES (?)" + try: + row_count = await driver.insert_update_delete(insert_sql, ("test",)) + assert row_count == 1 + + # Verify insertion + select_sql = "SELECT COUNT(*) FROM test_table WHERE name = ?" + count = await driver.select_value(select_sql, ("test",)) + assert count == 1 + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_select(asyncpg_config: AsyncpgConfig) -> None: + """Test selecting data.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create and populate test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ) + """ + await driver.execute_script(sql) + + insert_sql = "INSERT INTO test_table (name) VALUES (?)" + await driver.insert_update_delete(insert_sql, ("test",)) + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE id = ?" + try: + results = await driver.select(select_sql, (1,)) + assert len(results) == 1 + assert results[0]["name"] == "test" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +# Asyncpg uses positional ($n) parameters internally. +# The sqlspec driver converts '?' (tuple) and ':name' (dict) styles. +# We test these two styles as they are what the user interacts with via sqlspec. +@pytest.mark.parametrize( + "param_style", + [ + "tuple_binds", # Corresponds to '?' in SQL passed to sqlspec + "dict_binds", # Corresponds to ':name' in SQL passed to sqlspec + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_param_styles(asyncpg_config: AsyncpgConfig, param_style: str) -> None: + """Test different parameter styles expected by sqlspec.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ) + """ + await driver.execute_script(sql) + + # Insert test record based on param style + if param_style == "tuple_binds": + insert_sql = "INSERT INTO test_table (name) VALUES (?)" + params: Any = ("test",) + else: # dict_binds + insert_sql = "INSERT INTO test_table (name) VALUES (:name)" + params = {"name": "test"} + + try: + row_count = await driver.insert_update_delete(insert_sql, params) + assert row_count == 1 + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE id = ?" + results = await driver.select(select_sql, (1,)) + assert len(results) == 1 + assert results[0]["name"] == "test" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table")